mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 07:49:48 +00:00
Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0394fae446 | |||
| 602b8e66a6 | |||
| ab4dce6fed | |||
| 40f4386e4a | |||
| 87a91b4b08 | |||
| fadb900c36 | |||
| de0663226a | |||
| 0ca9d66cae | |||
| 2222f25da3 | |||
| acae8417aa | |||
| 2697f65cf6 | |||
| 74f42f218e | |||
| ca9d49e305 | |||
| 6705876d47 | |||
| aadbd27675 | |||
| 5221647b5e | |||
| 9c981300dd | |||
| f5b27aad1b | |||
| 75f1285507 | |||
| 33cedc2f71 | |||
| aa32e6c4ab | |||
| f906270ec4 | |||
| 733b6d84db | |||
| 8abc9037a3 | |||
| e4d4ac0bda | |||
| e79b2a439b | |||
| f9ae78ca74 | |||
| e1ced538e3 | |||
| 2a98602ad6 | |||
| a2f5b3571e | |||
| cecf2eff4f | |||
| 7e6b598a51 | |||
| 4fa41ba806 | |||
| 1de2b87a92 | |||
| e3c511db67 | |||
| aed4130d39 | |||
| d26349c692 | |||
| a9bce4732b | |||
| 86d69e3c1d |
+42
-42
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
|
||||
@@ -185,7 +185,7 @@ echo $HF_USER
|
||||
Use the standard recording command:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python src/lerobot/scripts/lerobot_record.py \
|
||||
--robot.type=earthrover_mini_plus \
|
||||
--teleop.type=keyboard_rover \
|
||||
--dataset.repo_id=your_username/dataset_name \
|
||||
|
||||
@@ -224,7 +224,7 @@ lerobot-record \
|
||||
--teleop.port=/dev/tty.usbmodem1201 \
|
||||
--teleop.id=right \
|
||||
--teleop.side=right \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.single_task="Hand recording test with video data" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
@@ -241,7 +241,7 @@ lerobot-replay \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
--robot.side=right \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_camera \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_camera \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
@@ -249,13 +249,13 @@ lerobot-replay \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
--job_name=hopejr \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=<USER>/hand_test_policy
|
||||
--policy.repo_id=nepyope/hand_test_policy
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
@@ -270,7 +270,7 @@ lerobot-record \
|
||||
--robot.side=right \
|
||||
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=<USER>/eval_hopejr \
|
||||
--dataset.repo_id=nepyope/eval_hopejr \
|
||||
--dataset.single_task="Evaluate hopejr hand policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
|
||||
+1
-1
@@ -60,7 +60,7 @@ policy.type=pi0
|
||||
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=./outputs/pi0_training \
|
||||
|
||||
@@ -56,7 +56,7 @@ policy.type=pi05
|
||||
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py\
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
|
||||
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
|
||||
Train with **no annotations** - uses linear progress from 0 to 1:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=single_stage \
|
||||
@@ -288,7 +288,7 @@ lerobot-train \
|
||||
Train with **dense annotations only** (sparse auto-generated):
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dense_only \
|
||||
@@ -307,7 +307,7 @@ lerobot-train \
|
||||
Train with **both sparse and dense annotations**:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dual \
|
||||
@@ -468,7 +468,7 @@ This script:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
|
||||
@@ -216,7 +216,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
@@ -266,7 +266,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
|
||||
@@ -12,7 +12,6 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
@@ -157,30 +156,6 @@ lerobot-edit-dataset \
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Show the information of datasets
|
||||
|
||||
Show the information of datasets such as number of episode, number of frame, File size and so on.
|
||||
No change will be made to the dataset
|
||||
|
||||
```bash
|
||||
|
||||
# Show dataset information without feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
|
||||
# Show dataset information with feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `parameters`: The flag to control show or no show dataset information with feature details.(default=false)
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
@@ -45,7 +45,7 @@ policy.type=wall_x
|
||||
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=wall_x \
|
||||
--output_dir=./outputs/wallx_training \
|
||||
|
||||
@@ -154,7 +154,7 @@ lerobot-train \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
|
||||
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||
--output_dir=./outputs/xvla_bimanual \
|
||||
--job_name=xvla_so101_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.episode=2
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
|
||||
Usage:
|
||||
# Basic usage with smolvla policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
@@ -58,16 +58,16 @@ Usage:
|
||||
--device=cuda
|
||||
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/reuben_pi0 \
|
||||
--dataset.repo_id=<USER>/so101_cube_in_cup \
|
||||
--policy.path=lipsop/reuben_pi0 \
|
||||
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--use_torch_compile=true \
|
||||
@@ -75,8 +75,8 @@ Usage:
|
||||
|
||||
# With torch.compile on CUDA (CUDA graphs disabled by default)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
@@ -84,8 +84,8 @@ Usage:
|
||||
|
||||
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune \
|
||||
|
||||
@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
@@ -41,7 +41,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
@@ -53,7 +53,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/pi05_check_rtc \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
@@ -11,7 +12,6 @@ from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
@@ -40,9 +40,8 @@ def run_learner(
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||
algorithm.make_optimizers()
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
@@ -84,26 +83,24 @@ def run_learner(
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
def batch_iter(b=batch):
|
||||
while True:
|
||||
yield b
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
|
||||
stats = algorithm.update(batch_iter())
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
log_dict = stats.to_log_dict()
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, "
|
||||
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
weights = algorithm.get_weights()
|
||||
parameters_queue.put_nowait(weights)
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
@@ -147,15 +144,15 @@ def run_actor(
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_weights = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_weights)
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy (returns full action: continuous + discrete)
|
||||
# Get action from policy
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
|
||||
+3
-3
@@ -76,9 +76,9 @@ dependencies = [
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
|
||||
@@ -150,7 +150,7 @@ class Camera(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -530,7 +530,7 @@ class OpenCVCamera(Camera):
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -201,7 +201,7 @@ class Reachy2Camera(Camera):
|
||||
return self.read()
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -573,7 +573,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -211,15 +211,3 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
# Algorithm name registered in RLAlgorithmConfig registry
|
||||
algorithm: str = "sac"
|
||||
|
||||
# Data mixer strategy name. Currently supports "online_offline"
|
||||
mixer: str = "online_offline"
|
||||
# Fraction sampled from online replay when using OnlineOfflineMixer
|
||||
online_ratio: float = 0.5
|
||||
|
||||
# RL trainer iterator
|
||||
async_prefetch: bool = True
|
||||
queue_size: int = 2
|
||||
|
||||
@@ -37,7 +37,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats, get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
@@ -1522,6 +1522,122 @@ def modify_tasks(
|
||||
return dataset
|
||||
|
||||
|
||||
def recompute_stats(
|
||||
dataset: LeRobotDataset,
|
||||
skip_image_video: bool = True,
|
||||
delta_action: bool = False,
|
||||
delta_exclude_joints: list[str] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Recompute stats.json from scratch by iterating all episodes.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobotDataset to recompute stats for.
|
||||
skip_image_video: If True (default), only recompute stats for numeric features
|
||||
(action, state, etc.) and keep existing image/video stats unchanged.
|
||||
delta_action: If True, compute action stats as delta (action - state).
|
||||
Useful when training with use_delta_actions=True so normalization matches.
|
||||
delta_exclude_joints: Joint names to exclude from delta conversion when
|
||||
delta_action=True. These dims keep absolute stats. Uses dataset's
|
||||
action feature names to build the mask. Default: ["gripper"].
|
||||
|
||||
Returns:
|
||||
The same dataset with updated stats.
|
||||
"""
|
||||
features = dataset.meta.features
|
||||
numeric_features = {
|
||||
k: v for k, v in features.items()
|
||||
if v["dtype"] not in ["image", "video", "string"]
|
||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||
}
|
||||
|
||||
if skip_image_video:
|
||||
features_to_compute = numeric_features
|
||||
else:
|
||||
features_to_compute = {
|
||||
k: v for k, v in features.items()
|
||||
if v["dtype"] != "string"
|
||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||
}
|
||||
|
||||
# Build delta mask if delta_action is enabled
|
||||
delta_mask = None
|
||||
if delta_action and "action" in features and "observation.state" in features:
|
||||
if delta_exclude_joints is None:
|
||||
delta_exclude_joints = ["gripper"]
|
||||
action_names = features["action"].get("names")
|
||||
if action_names is not None:
|
||||
exclude = set(delta_exclude_joints)
|
||||
delta_mask = [n not in exclude for n in action_names]
|
||||
else:
|
||||
action_dim = features["action"]["shape"][0]
|
||||
delta_mask = [True] * action_dim
|
||||
# Only recompute action stats when delta is enabled — state stays unchanged
|
||||
features_to_compute = {"action": features["action"]}
|
||||
logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
|
||||
else:
|
||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
# Also need state for delta computation even though we don't recompute state stats
|
||||
needs_state = delta_mask is not None
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
for ep_idx in sorted(df["episode_index"].unique()):
|
||||
ep_df = df[df["episode_index"] == ep_idx]
|
||||
episode_data = {}
|
||||
for key in numeric_keys:
|
||||
if key in ep_df.columns:
|
||||
values = ep_df[key].values
|
||||
if hasattr(values[0], "__len__"):
|
||||
episode_data[key] = np.stack(values)
|
||||
else:
|
||||
episode_data[key] = np.array(values)
|
||||
|
||||
# Apply delta conversion to actions before computing stats
|
||||
if delta_mask is not None and "action" in episode_data:
|
||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||
|
||||
# Load state for delta even if we're not computing state stats
|
||||
if needs_state and "observation.state" in ep_df.columns:
|
||||
state_values = ep_df["observation.state"].values
|
||||
if hasattr(state_values[0], "__len__"):
|
||||
states = np.stack(state_values)
|
||||
else:
|
||||
states = np.array(state_values)
|
||||
actions_t = torch.from_numpy(episode_data["action"]).float()
|
||||
states_t = torch.from_numpy(states).float()
|
||||
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
|
||||
|
||||
ep_stats = compute_episode_stats(episode_data, features_to_compute)
|
||||
all_episode_stats.append(ep_stats)
|
||||
|
||||
if not all_episode_stats:
|
||||
logging.warning("No episode stats computed")
|
||||
return dataset
|
||||
|
||||
new_stats = aggregate_stats(all_episode_stats)
|
||||
|
||||
# Merge: keep existing stats for features we didn't recompute
|
||||
if dataset.meta.stats:
|
||||
for key, value in dataset.meta.stats.items():
|
||||
if key not in new_stats:
|
||||
new_stats[key] = value
|
||||
|
||||
write_stats(new_stats, dataset.root)
|
||||
dataset.meta.stats = new_stats
|
||||
|
||||
logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
|
||||
return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
|
||||
@@ -656,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
|
||||
@@ -122,9 +122,19 @@ def load_nested_dataset(
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
with SuppressProgressBars():
|
||||
# We use .from_parquet() memory-mapped loading for efficiency
|
||||
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
|
||||
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
|
||||
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
|
||||
# PyArrow loads the entire dataset into memory
|
||||
if episodes is None:
|
||||
return Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
|
||||
arrow_dataset = pa_ds.dataset(paths, format="parquet")
|
||||
filter_expr = pa_ds.field("episode_index").isin(episodes)
|
||||
table = arrow_dataset.to_table(filter=filter_expr)
|
||||
|
||||
if features is not None:
|
||||
table = table.cast(features.arrow_schema)
|
||||
|
||||
return Dataset(table)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
|
||||
@@ -529,7 +529,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
|
||||
@@ -470,6 +470,13 @@ def make_policy(
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
|
||||
# Store action feature names for delta_exclude_joints support
|
||||
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
|
||||
@@ -50,6 +50,13 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -21,8 +21,10 @@ import torch
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -126,7 +128,13 @@ def make_pi0_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
delta_step = DeltaActionsProcessorStep(
|
||||
enabled=config.use_delta_actions,
|
||||
exclude_joints=getattr(config, "delta_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -138,6 +146,7 @@ def make_pi0_pre_post_processors(
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
delta_step,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
@@ -149,6 +158,7 @@ def make_pi0_pre_post_processors(
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
@@ -50,6 +50,13 @@ class PI05Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -25,7 +25,9 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -129,10 +131,19 @@ def make_pi05_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
delta_step = DeltaActionsProcessorStep(
|
||||
enabled=config.use_delta_actions,
|
||||
exclude_joints=getattr(config, "delta_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
delta_step,
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
@@ -154,6 +165,7 @@ def make_pi05_pre_post_processors(
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
@@ -41,6 +41,9 @@ class PI0FastConfig(PreTrainedConfig):
|
||||
max_action_dim: int = 32
|
||||
max_action_tokens: int = 256
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -48,12 +48,14 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.processor.delta_action_processor import to_absolute_actions
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -1315,6 +1317,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
action_tokens, action_horizon=action_horizon, action_dim=action_dim
|
||||
)
|
||||
|
||||
if self.config.use_delta_actions and OBS_STATE in batch:
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
continuous_actions = to_absolute_actions(
|
||||
continuous_actions, state, [True] * continuous_actions.shape[-1]
|
||||
)
|
||||
|
||||
return continuous_actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
|
||||
@@ -27,6 +27,7 @@ from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -147,6 +148,7 @@ def make_pi0_fast_pre_post_processors(
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.policies.rlt.configuration_rlt import RLTConfig
|
||||
from lerobot.policies.rlt.modeling_rlt import RLTPolicy
|
||||
|
||||
__all__ = ["RLTConfig", "RLTPolicy"]
|
||||
@@ -1,156 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT (RL Token) policy configuration.
|
||||
|
||||
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
|
||||
(Xu et al., Physical Intelligence, 2026)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.policies.sac.configuration_sac import ActorLearnerConfig, ConcurrencyConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTokenConfig:
|
||||
"""Configuration for the RL-token encoder/decoder transformer."""
|
||||
|
||||
input_dim: int = 2048
|
||||
rl_token_dim: int = 2048
|
||||
num_encoder_layers: int = 2
|
||||
num_decoder_layers: int = 2
|
||||
num_heads: int = 8
|
||||
ff_dim: int = 2048
|
||||
dropout: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTActorConfig:
|
||||
"""Configuration for the lightweight RL actor MLP."""
|
||||
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
std: float = 0.1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTCriticConfig:
|
||||
"""Configuration for the RLT critic MLP."""
|
||||
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("rlt")
|
||||
@dataclass
|
||||
class RLTConfig(PreTrainedConfig):
|
||||
"""Configuration for the RLT (RL Token) policy.
|
||||
|
||||
RLT adds an RL-token encoder/decoder to a frozen VLA backbone, then trains
|
||||
a lightweight actor-critic head using the RL token as state representation.
|
||||
The frozen VLA also provides reference action chunks that the actor refines.
|
||||
"""
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
OBS_IMAGE: {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
OBS_STATE: {"min": [0.0], "max": [1.0]},
|
||||
ACTION: {"min": [0.0], "max": [1.0]},
|
||||
}
|
||||
)
|
||||
|
||||
# ── Device ──
|
||||
device: str = "cuda"
|
||||
storage_device: str = "cpu"
|
||||
|
||||
# ── VLA backbone ──
|
||||
vla_checkpoint: str | None = None
|
||||
|
||||
# ── RL-token ──
|
||||
rl_token: RLTokenConfig = field(default_factory=RLTokenConfig)
|
||||
|
||||
# ── Actor / Critic heads ──
|
||||
actor: RLTActorConfig = field(default_factory=RLTActorConfig)
|
||||
critic: RLTCriticConfig = field(default_factory=RLTCriticConfig)
|
||||
|
||||
# ── Action chunks ──
|
||||
chunk_size: int = 10
|
||||
vla_chunk_size: int = 50
|
||||
|
||||
# ── Training parameters ──
|
||||
online_steps: int = 50000
|
||||
offline_steps: int = 5000
|
||||
online_buffer_capacity: int = 100000
|
||||
offline_buffer_capacity: int = 100000
|
||||
online_step_before_learning: int = 500
|
||||
warmup_steps: int = 500
|
||||
async_prefetch: bool = False
|
||||
|
||||
# ── Algorithm hyperparameters ──
|
||||
utd_ratio: int = 5
|
||||
policy_update_freq: int = 2
|
||||
discount: float = 0.99
|
||||
critic_lr: float = 3e-4
|
||||
actor_lr: float = 3e-4
|
||||
rl_token_lr: float = 1e-4
|
||||
tau: float = 0.005
|
||||
clip_grad_norm: float = 10.0
|
||||
num_critics: int = 2
|
||||
bc_reg_coeff: float = 0.1
|
||||
ref_dropout: float = 0.5
|
||||
chunk_stride: int = 2
|
||||
vla_finetune_weight: float = 0.0
|
||||
|
||||
# ── Distributed ──
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
def get_optimizer_preset(self):
|
||||
return None
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if ACTION not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,318 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT (RL Token) policy networks.
|
||||
|
||||
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
|
||||
(Xu et al., Physical Intelligence, 2026)
|
||||
|
||||
Architecture:
|
||||
- RLTokenEncoder: compresses VLA token embeddings into a single compact RL token
|
||||
- RLTokenDecoder: reconstructs VLA embeddings from the RL token (Stage 1 training only)
|
||||
- RLTActor: refines VLA reference action chunks conditioned on (z_rl, proprioception, ref_action)
|
||||
- RLTCritic: Q(x, action_chunk) where x = (z_rl, proprioception)
|
||||
- RLTPolicy: bundles RL-token modules + actor into a PreTrainedPolicy for inference
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rlt.configuration_rlt import RLTConfig
|
||||
|
||||
# ── Building blocks ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""Simple feedforward network with ReLU activations."""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_dims: list[int], output_dim: int):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = []
|
||||
prev = input_dim
|
||||
for h in hidden_dims:
|
||||
layers.append(nn.Linear(prev, h))
|
||||
layers.append(nn.ReLU())
|
||||
prev = h
|
||||
layers.append(nn.Linear(prev, output_dim))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# ── RL Token Encoder ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTokenEncoder(nn.Module):
|
||||
"""Compress VLA token embeddings into a single RL token via a small transformer.
|
||||
|
||||
Appends a learnable ``e_rl`` embedding to the VLA token sequence, processes
|
||||
through transformer encoder layers, and returns the output at the ``e_rl``
|
||||
position as the RL token ``z_rl``.
|
||||
|
||||
Paper Eq. 1: z_rl = g_phi([z_{1:M}, e_rl])_{M+1}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
rl_token_dim: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
ff_dim: int,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.rl_token_dim = rl_token_dim
|
||||
|
||||
self.e_rl = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02)
|
||||
|
||||
if input_dim != rl_token_dim:
|
||||
self.input_proj = nn.Linear(input_dim, rl_token_dim)
|
||||
else:
|
||||
self.input_proj = nn.Identity()
|
||||
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=rl_token_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=ff_dim,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
|
||||
def forward(self, z_vla: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
z_vla: VLA token embeddings, shape ``(B, M, D)``.
|
||||
|
||||
Returns:
|
||||
RL token ``z_rl``, shape ``(B, rl_token_dim)``.
|
||||
"""
|
||||
batch_size = z_vla.shape[0]
|
||||
e_rl = self.e_rl.expand(batch_size, -1, -1)
|
||||
seq = torch.cat([z_vla, e_rl], dim=1) # (B, M+1, D)
|
||||
seq = self.input_proj(seq)
|
||||
out = self.transformer(seq)
|
||||
z_rl = out[:, -1, :] # output at e_rl position
|
||||
return z_rl
|
||||
|
||||
|
||||
# ── RL Token Decoder ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTokenDecoder(nn.Module):
|
||||
"""Autoregressively reconstruct VLA embeddings from z_rl.
|
||||
|
||||
Used only during Stage 1 (offline RL-token training).
|
||||
|
||||
Paper Eq. 2: L_ro = E[sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rl_token_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
ff_dim: int,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
|
||||
if rl_token_dim != output_dim:
|
||||
self.rl_proj = nn.Linear(rl_token_dim, output_dim)
|
||||
else:
|
||||
self.rl_proj = nn.Identity()
|
||||
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=output_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=ff_dim,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
|
||||
self.output_head = nn.Linear(output_dim, output_dim)
|
||||
|
||||
def forward(self, z_rl: Tensor, z_vla_stopped: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
z_rl: RL token, shape ``(B, D_rl)``.
|
||||
z_vla_stopped: Stop-gradient VLA embeddings, shape ``(B, M, D)``.
|
||||
|
||||
Returns:
|
||||
Reconstructed embeddings, shape ``(B, M, D)``.
|
||||
"""
|
||||
seq_len = z_vla_stopped.shape[1]
|
||||
z_rl_proj = self.rl_proj(z_rl).unsqueeze(1)
|
||||
|
||||
target = torch.cat([z_rl_proj, z_vla_stopped[:, :-1, :]], dim=1)
|
||||
|
||||
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=z_rl.device)
|
||||
|
||||
decoded = self.transformer(
|
||||
tgt=target,
|
||||
memory=z_rl_proj,
|
||||
tgt_mask=causal_mask,
|
||||
)
|
||||
return self.output_head(decoded) # (B, M, D)
|
||||
|
||||
|
||||
# ── Actor ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTActor(nn.Module):
|
||||
"""Lightweight actor that refines VLA reference action chunks.
|
||||
|
||||
Paper Eq. 4: pi_theta(a_{1:C} | x, a_tilde_{1:C}) = N(mu_theta(x, a_tilde), sigma^2 I)
|
||||
|
||||
The actor is conditioned on both the RL state and the VLA's proposed action
|
||||
chunk, acting as a "VLA-guided action editor".
|
||||
"""
|
||||
|
||||
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int], std: float = 0.1):
|
||||
super().__init__()
|
||||
input_dim = state_dim + action_chunk_dim
|
||||
self.net = MLP(input_dim, hidden_dims, action_chunk_dim)
|
||||
self.log_std = math.log(std)
|
||||
|
||||
def forward(self, state: Tensor, ref_action_chunk: Tensor) -> Tensor:
|
||||
"""Return the mean action chunk.
|
||||
|
||||
Args:
|
||||
state: RL state ``x = (z_rl, proprioception)``, shape ``(B, state_dim)``.
|
||||
ref_action_chunk: Flattened VLA reference chunk, shape ``(B, C*d)``.
|
||||
|
||||
Returns:
|
||||
Refined action chunk (mean), shape ``(B, C*d)``.
|
||||
"""
|
||||
x = torch.cat([state, ref_action_chunk], dim=-1)
|
||||
return self.net(x)
|
||||
|
||||
def sample(self, state: Tensor, ref_action_chunk: Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""Sample an action and return (action, log_prob)."""
|
||||
mean = self.forward(state, ref_action_chunk)
|
||||
std = math.exp(self.log_std)
|
||||
noise = torch.randn_like(mean) * std
|
||||
action = mean + noise
|
||||
log_prob = -0.5 * (noise / std).pow(2).sum(dim=-1) - mean.shape[-1] * math.log(
|
||||
std * math.sqrt(2 * math.pi)
|
||||
)
|
||||
return action, log_prob
|
||||
|
||||
|
||||
# ── Policy (inference bundle) ────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTPolicy(PreTrainedPolicy):
|
||||
"""RLT policy — bundles the RL-token encoder and actor for inference.
|
||||
|
||||
The frozen VLA backbone is **not** part of this module; it is loaded
|
||||
separately and its embeddings / reference actions are passed in via the
|
||||
observation dict (populated by the actor process or a preprocessor).
|
||||
|
||||
During training, the :class:`RLTAlgorithm` holds the critic, target networks,
|
||||
and optimizers. This class only contains what is needed for ``select_action``.
|
||||
"""
|
||||
|
||||
name = "rlt"
|
||||
config_class = RLTConfig
|
||||
|
||||
def __init__(self, config: RLTConfig, dataset_stats=None):
|
||||
super().__init__(config, dataset_stats)
|
||||
action_dim = config.output_features["action"].shape[0]
|
||||
action_chunk_dim = config.chunk_size * action_dim
|
||||
prop_feature = config.input_features.get("observation.state", None)
|
||||
proprioception_dim = prop_feature.shape[0] if prop_feature is not None else 0
|
||||
|
||||
state_dim = config.rl_token.rl_token_dim + proprioception_dim
|
||||
|
||||
# RL-token encoder (frozen after Stage 1)
|
||||
self.rl_token_encoder = RLTokenEncoder(
|
||||
input_dim=config.rl_token.input_dim,
|
||||
rl_token_dim=config.rl_token.rl_token_dim,
|
||||
num_layers=config.rl_token.num_encoder_layers,
|
||||
num_heads=config.rl_token.num_heads,
|
||||
ff_dim=config.rl_token.ff_dim,
|
||||
dropout=config.rl_token.dropout,
|
||||
)
|
||||
|
||||
# RL-token decoder (used only during Stage 1 training)
|
||||
self.rl_token_decoder = RLTokenDecoder(
|
||||
rl_token_dim=config.rl_token.rl_token_dim,
|
||||
output_dim=config.rl_token.input_dim,
|
||||
num_layers=config.rl_token.num_decoder_layers,
|
||||
num_heads=config.rl_token.num_heads,
|
||||
ff_dim=config.rl_token.ff_dim,
|
||||
dropout=config.rl_token.dropout,
|
||||
)
|
||||
|
||||
# Actor MLP
|
||||
self.actor = RLTActor(
|
||||
state_dim=state_dim,
|
||||
action_chunk_dim=action_chunk_dim,
|
||||
hidden_dims=config.actor.hidden_dims,
|
||||
std=config.actor.std,
|
||||
)
|
||||
|
||||
self._action_dim = action_dim
|
||||
self._action_chunk_dim = action_chunk_dim
|
||||
self._state_dim = state_dim
|
||||
self._proprioception_dim = proprioception_dim
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a refined action chunk given an observation.
|
||||
|
||||
Expects the observation dict to contain:
|
||||
- ``"observation.vla_embeddings"``: VLA internal token embeddings ``(M, D)``
|
||||
- ``"observation.reference_action"``: VLA reference chunk ``(C*d,)``
|
||||
- ``"observation.state"`` (optional): proprioceptive state ``(P,)``
|
||||
|
||||
Returns:
|
||||
Action chunk tensor of shape ``(C*d,)``.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
vla_emb = batch["observation.vla_embeddings"]
|
||||
if vla_emb.dim() == 2:
|
||||
vla_emb = vla_emb.unsqueeze(0)
|
||||
|
||||
z_rl = self.rl_token_encoder(vla_emb) # (1, D_rl)
|
||||
|
||||
parts = [z_rl]
|
||||
if "observation.state" in batch and self._proprioception_dim > 0:
|
||||
prop = batch["observation.state"]
|
||||
if prop.dim() == 1:
|
||||
prop = prop.unsqueeze(0)
|
||||
parts.append(prop)
|
||||
|
||||
state = torch.cat(parts, dim=-1)
|
||||
|
||||
ref = batch["observation.reference_action"]
|
||||
if ref.dim() == 1:
|
||||
ref = ref.unsqueeze(0)
|
||||
|
||||
action = self.actor(state, ref)
|
||||
return action.squeeze(0)
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
@@ -15,11 +15,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from typing import Literal
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
@@ -47,13 +52,20 @@ class SACPolicy(
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self.encoder = SACObservationEncoder(config)
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_discrete_critic()
|
||||
self._init_temperature()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
"actor": [self.actor.parameters()],
|
||||
"actor": [
|
||||
p
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["discrete_critic"] = self.discrete_critic.parameters()
|
||||
@@ -71,9 +83,10 @@ class SACPolicy(
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
|
||||
observations_features = None
|
||||
if self.encoder.has_images:
|
||||
observations_features = self.encoder.get_cached_image_features(batch)
|
||||
if self.shared_encoder and self.actor.encoder.has_images:
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
@@ -84,35 +97,372 @@ class SACPolicy(
|
||||
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Actor forward pass."""
|
||||
observations = batch.get("state", batch)
|
||||
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
|
||||
actions, log_probs, means = self.actor(observations, observation_features)
|
||||
return {"action": actions, "log_prob": log_probs, "action_mean": means}
|
||||
"""Compute the loss for the given model
|
||||
|
||||
def _init_actor(self, continuous_action_dim: int) -> None:
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder,
|
||||
network=MLP(input_dim=self.encoder.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=False,
|
||||
**asdict(self.config.policy_kwargs),
|
||||
Args:
|
||||
batch: Dictionary containing:
|
||||
- action: Action tensor
|
||||
- reward: Reward tensor
|
||||
- state: Observations tensor dict
|
||||
- next_state: Next observations tensor dict
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
|
||||
loss_critic = self.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
loss_discrete_critic = self.compute_loss_discrete_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
return {"loss_discrete_critic": loss_discrete_critic}
|
||||
if model == "actor":
|
||||
return {
|
||||
"loss_actor": self.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
if model == "temperature":
|
||||
return {
|
||||
"loss_temperature": self.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_param, param in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
def _init_discrete_critic(self) -> None:
|
||||
if self.config.num_discrete_actions is None:
|
||||
self.discrete_critic = None
|
||||
return
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_discrete_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features=None,
|
||||
next_observation_features=None,
|
||||
complementary_info=None,
|
||||
):
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self.discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self.discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self.discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self,
|
||||
observations,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critics()
|
||||
|
||||
def _init_discrete_critics(self):
|
||||
"""Build discrete discrete critic ensemble and target networks."""
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder,
|
||||
input_dim=self.encoder.output_dim,
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
# NOTE: The actor select only the continuous action part
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=self.shared_encoder,
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
|
||||
@@ -27,18 +27,18 @@ Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 5
|
||||
|
||||
@@ -714,12 +714,12 @@ Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 10
|
||||
""",
|
||||
|
||||
@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
@@ -40,7 +40,7 @@ and an action expert.
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
@@ -28,7 +28,14 @@ from .core import (
|
||||
RobotObservation,
|
||||
TransitionKey,
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .delta_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
MapTensorToDeltaActionDictStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .factory import (
|
||||
make_default_processors,
|
||||
@@ -44,7 +51,6 @@ from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
@@ -88,7 +94,6 @@ __all__ = [
|
||||
"DoneProcessorStep",
|
||||
"EnvAction",
|
||||
"EnvTransition",
|
||||
"GymHILAdapterProcessorStep",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"IdentityProcessorStep",
|
||||
@@ -99,6 +104,8 @@ __all__ = [
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"DeltaActionsProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
@@ -128,6 +135,8 @@ __all__ = [
|
||||
"transition_to_batch",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessorStep",
|
||||
"to_absolute_actions",
|
||||
"to_delta_actions",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -14,12 +14,54 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .core import PolicyAction, RobotAction
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
|
||||
|
||||
def to_delta_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||
"""Convert absolute actions to delta: delta = action - state (for masked dims).
|
||||
|
||||
Args:
|
||||
actions: (B, T, action_dim) or (B, action_dim).
|
||||
state: (B, state_dim). Broadcast across time dimension.
|
||||
mask: Which dims to convert. Can be shorter than action_dim.
|
||||
"""
|
||||
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
|
||||
dims = mask_t.shape[0]
|
||||
state_offset = state[..., :dims] * mask_t
|
||||
if actions.ndim == 3:
|
||||
state_offset = state_offset.unsqueeze(-2)
|
||||
actions = actions.clone()
|
||||
actions[..., :dims] -= state_offset
|
||||
return actions
|
||||
|
||||
|
||||
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||
"""Convert delta actions back to absolute: absolute = delta + state (for masked dims).
|
||||
|
||||
Args:
|
||||
actions: (B, T, action_dim) or (B, action_dim).
|
||||
state: (B, state_dim). Broadcast across time dimension.
|
||||
mask: Which dims to convert. Can be shorter than action_dim.
|
||||
"""
|
||||
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
|
||||
dims = mask_t.shape[0]
|
||||
state_offset = state[..., :dims] * mask_t
|
||||
if actions.ndim == 3:
|
||||
state_offset = state_offset.unsqueeze(-2)
|
||||
actions = actions.clone()
|
||||
actions[..., :dims] += state_offset
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||
@@ -141,3 +183,126 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@dataclass
|
||||
class DeltaActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to delta actions (action -= state) for masked dimensions.
|
||||
|
||||
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
|
||||
trains on relative offsets instead of absolute positions.
|
||||
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
|
||||
the conversion during postprocessing.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the delta conversion.
|
||||
exclude_joints: Joint names to keep absolute (not converted to delta).
|
||||
action_names: Action dimension names from dataset metadata, used to build
|
||||
the mask from exclude_joints. If None, all dims are converted.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
exclude_joints: list[str] = field(default_factory=list)
|
||||
action_names: list[str] | None = None
|
||||
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def _build_mask(self, action_dim: int) -> list[bool]:
|
||||
if not self.exclude_joints or self.action_names is None:
|
||||
return [True] * action_dim
|
||||
|
||||
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||
if not exclude_tokens:
|
||||
return [True] * action_dim
|
||||
|
||||
mask = []
|
||||
for name in self.action_names[:action_dim]:
|
||||
action_name = str(name).lower()
|
||||
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
|
||||
mask.append(not is_excluded)
|
||||
|
||||
if len(mask) < action_dim:
|
||||
mask.extend([True] * (action_dim - len(mask)))
|
||||
|
||||
return mask
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||
state = observation.get(OBS_STATE) if observation else None
|
||||
|
||||
# Always cache state for the paired AbsoluteActionsProcessorStep
|
||||
if state is not None:
|
||||
self._last_state = state
|
||||
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None or state is None:
|
||||
return new_transition
|
||||
|
||||
mask = self._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled, "exclude_joints": self.exclude_joints}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("absolute_actions_processor")
|
||||
@dataclass
|
||||
class AbsoluteActionsProcessorStep(ProcessorStep):
|
||||
"""Converts delta actions back to absolute actions (action += state) for all dimensions.
|
||||
|
||||
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
|
||||
predicted deltas are converted back to absolute positions for execution.
|
||||
Reads the cached state from its paired DeltaActionsProcessorStep.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the absolute conversion.
|
||||
delta_step: Reference to the paired DeltaActionsProcessorStep that caches state.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
delta_step: DeltaActionsProcessorStep | None = field(default=None, repr=False)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
if self.delta_step is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires a paired DeltaActionsProcessorStep "
|
||||
"but delta_step is None. Ensure delta_step is set when constructing the postprocessor."
|
||||
)
|
||||
|
||||
if self.delta_step._last_state is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires state from DeltaActionsProcessorStep "
|
||||
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
|
||||
)
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return new_transition
|
||||
|
||||
mask = self.delta_step._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_absolute_actions(
|
||||
action, self.delta_step._last_state, mask
|
||||
)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -20,7 +20,6 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .converters import to_tensor
|
||||
from .core import EnvAction, EnvTransition, PolicyAction
|
||||
from .hil_processor import TELEOP_ACTION_KEY
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@@ -90,13 +89,6 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
new_transition[TransitionKey.ACTION] = torch_action
|
||||
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if TELEOP_ACTION_KEY in complementary_data:
|
||||
teleop_action = complementary_data[TELEOP_ACTION_KEY]
|
||||
if isinstance(teleop_action, EnvAction):
|
||||
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -312,37 +312,6 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
|
||||
class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
|
||||
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
transition[TransitionKey.INFO] = info
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
|
||||
@@ -131,15 +131,6 @@ class _NormalizationMixin:
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
@@ -158,7 +149,6 @@ class _NormalizationMixin:
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -208,7 +198,6 @@ class _NormalizationMixin:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
self._reshape_visual_stats()
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
@@ -219,7 +208,6 @@ class _NormalizationMixin:
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
@@ -343,11 +331,9 @@ class _NormalizationMixin:
|
||||
)
|
||||
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
# Avoid division by zero by adding a small epsilon.
|
||||
denom = std + self.eps
|
||||
if inverse:
|
||||
return tensor * std + mean
|
||||
return (tensor - mean) / denom
|
||||
return tensor * (std + 1e-6) + mean
|
||||
return (tensor - mean) / (std + 1e-6)
|
||||
|
||||
if norm_mode == NormalizationMode.MIN_MAX:
|
||||
min_val = stats.get("min", None)
|
||||
@@ -379,11 +365,7 @@ class _NormalizationMixin:
|
||||
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
|
||||
)
|
||||
|
||||
denom = q99 - q01
|
||||
# Avoid division by zero by adding epsilon when quantiles are identical
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
||||
)
|
||||
denom = q99 - q01 + 1e-6
|
||||
if inverse:
|
||||
return (tensor + 1.0) * denom / 2.0 + q01
|
||||
return 2.0 * (tensor - q01) / denom - 1.0
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
+19
-9
@@ -61,7 +61,7 @@ from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
@@ -248,16 +248,16 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
policy = make_policy(
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# TODO: Re-enable processor pipeline once refactoring is validated against main
|
||||
# preprocessor, postprocessor = None, None
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
@@ -288,6 +288,7 @@ def act_with_policy(
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
@@ -648,12 +649,12 @@ def interactions_stream(
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
|
||||
"""Load the latest policy weights from the learner."""
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dicts = bytes_to_state_dict(bytes_state_dict)
|
||||
|
||||
# TODO: check encoder parameter synchronization possible issues:
|
||||
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
|
||||
# instead of the updated encoder params from critic (which is optimized separately)
|
||||
@@ -663,9 +664,18 @@ def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue,
|
||||
# - Send critic's encoder state when shared_encoder=True
|
||||
# - Skip encoder params entirely when freeze_vision_encoder=True
|
||||
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
|
||||
|
||||
# Load actor state dict
|
||||
state_dicts = move_state_dict_to_device(state_dicts, device=device)
|
||||
policy.load_state_dict(state_dicts)
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
policy.actor.load_state_dict(actor_state_dict)
|
||||
|
||||
# Load discrete critic if present
|
||||
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
)
|
||||
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
|
||||
|
||||
# Utilities functions
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import (
|
||||
RLAlgorithm,
|
||||
RLAlgorithmConfig,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.algorithms.rlt import RLTAlgorithm, RLTAlgorithmConfig
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
|
||||
|
||||
def make_algorithm(
|
||||
policy: torch.nn.Module,
|
||||
policy_cfg,
|
||||
*,
|
||||
algorithm_name: str,
|
||||
) -> RLAlgorithm:
|
||||
"""Construct an :class:`RLAlgorithm` from a policy and its config.
|
||||
|
||||
Algorithm selection is explicit via ``algorithm_name`` (from
|
||||
``cfg.algorithm``).
|
||||
|
||||
This is fully registry-driven — adding a new algorithm only requires
|
||||
registering an ``RLAlgorithmConfig`` subclass; no changes here.
|
||||
|
||||
The returned algorithm has **no optimizers** yet. On the learner side,
|
||||
call ``algorithm.make_optimizers()`` afterwards to create them. On the
|
||||
actor side (inference-only), leave them empty.
|
||||
|
||||
Args:
|
||||
policy: Instantiated policy (e.g. ``SACPolicy``).
|
||||
policy_cfg: The policy's ``PreTrainedConfig`` with the hyper-parameters
|
||||
expected by the algorithm config's ``from_policy_config`` class-method.
|
||||
algorithm_name: Algorithm registry key to instantiate.
|
||||
"""
|
||||
known = RLAlgorithmConfig.get_known_choices()
|
||||
if algorithm_name not in known:
|
||||
raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}")
|
||||
|
||||
config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name)
|
||||
algo_config = config_cls.from_policy_config(policy_cfg)
|
||||
return algo_config.build_algorithm(policy)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RLAlgorithm",
|
||||
"RLAlgorithmConfig",
|
||||
"TrainingStats",
|
||||
"SACAlgorithm",
|
||||
"SACAlgorithmConfig",
|
||||
"RLTAlgorithm",
|
||||
"RLTAlgorithmConfig",
|
||||
"make_algorithm",
|
||||
]
|
||||
@@ -1,183 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Base classes for RL algorithms.
|
||||
|
||||
Defines the abstract interface that every algorithm must implement, a registry
|
||||
for algorithm configs, and a dataclass for training statistics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.data_sources.data_mixer import DataMixer
|
||||
|
||||
BatchType = dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingStats:
|
||||
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
|
||||
|
||||
# Generic containers for all algorithms
|
||||
losses: dict[str, float] = field(default_factory=dict)
|
||||
grad_norms: dict[str, float] = field(default_factory=dict)
|
||||
extra: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_log_dict(self) -> dict[str, float]:
|
||||
"""Flatten all stats into a single dict for logging."""
|
||||
|
||||
d: dict[str, float] = {}
|
||||
for name, val in self.losses.items():
|
||||
d[name] = val
|
||||
for name, val in self.grad_norms.items():
|
||||
d[f"{name}_grad_norm"] = val
|
||||
for name, val in self.extra.items():
|
||||
d[name] = val
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLAlgorithmConfig(draccus.ChoiceRegistry):
|
||||
"""Registry for algorithm configs."""
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm:
|
||||
"""Construct the :class:`RLAlgorithm` for this config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()")
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
|
||||
"""Build an algorithm config from a policy config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
|
||||
|
||||
|
||||
class RLAlgorithm(abc.ABC):
|
||||
"""Base for all RL algorithms."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One complete training step.
|
||||
|
||||
The algorithm calls ``next(batch_iterator)`` as many times as it
|
||||
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
|
||||
The iterator is owned by the trainer; the algorithm just consumes
|
||||
from it.
|
||||
"""
|
||||
...
|
||||
|
||||
def supports_offline_phase(self) -> bool:
|
||||
"""Whether this algorithm has an offline pretraining phase.
|
||||
|
||||
Algorithms like RLT (RL-token training) or ConRFT (Cal-QL pretraining)
|
||||
return ``True`` here. The learner checks this before the main online
|
||||
loop and routes to :meth:`offline_update` accordingly.
|
||||
"""
|
||||
return False
|
||||
|
||||
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One offline training step (called before any online collection).
|
||||
|
||||
Only called when :meth:`supports_offline_phase` returns ``True``.
|
||||
Uses the same iterator protocol as :meth:`update`.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} does not implement offline_update(). "
|
||||
"Either override this method or return False from supports_offline_phase()."
|
||||
)
|
||||
|
||||
def transition_to_online(self) -> None: # noqa: B027
|
||||
"""Called once when switching from offline to online phase.
|
||||
|
||||
Use this to freeze modules trained offline, rebuild optimizers for the
|
||||
online phase, reset step counters, etc.
|
||||
|
||||
Default is a no-op; subclasses override when they have an offline phase.
|
||||
"""
|
||||
|
||||
def configure_data_iterator(
|
||||
self,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
) -> Iterator[BatchType]:
|
||||
"""Create the data iterator this algorithm needs.
|
||||
|
||||
The default implementation uses the standard ``data_mixer.get_iterator()``.
|
||||
Algorithms that need specialised sampling should override this method.
|
||||
"""
|
||||
return data_mixer.get_iterator(
|
||||
batch_size=batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
def make_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Create, store, and return the optimizers needed for training.
|
||||
|
||||
Called on the **learner** side after construction. Subclasses must
|
||||
override this with algorithm-specific optimizer setup.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Return optimizers for checkpointing / external scheduling."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def optimization_step(self) -> int:
|
||||
"""Current learner optimization step.
|
||||
|
||||
Part of the stable contract for checkpoint/resume. Algorithms can
|
||||
either use this default storage or override for custom behavior.
|
||||
"""
|
||||
return getattr(self, "_optimization_step", 0)
|
||||
|
||||
@optimization_step.setter
|
||||
def optimization_step(self, value: int) -> None:
|
||||
self._optimization_step = int(value)
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Policy state-dict to push to actors."""
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner (inverse of ``get_weights``)."""
|
||||
|
||||
@torch.no_grad()
|
||||
def get_observation_features(
|
||||
self, observations: Tensor, next_observations: Tensor
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Pre-compute observation features (e.g. frozen encoder cache).
|
||||
|
||||
Returns ``(None, None)`` when caching is not applicable.
|
||||
"""
|
||||
return None, None
|
||||
@@ -1,18 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
|
||||
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
|
||||
|
||||
__all__ = ["RLTAlgorithm", "RLTAlgorithmConfig"]
|
||||
@@ -1,83 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT algorithm configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
|
||||
|
||||
|
||||
@RLAlgorithmConfig.register_subclass("rlt")
|
||||
@dataclass
|
||||
class RLTAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""RLT-specific hyper-parameters that control the update loop."""
|
||||
|
||||
# ── Action chunks ──
|
||||
chunk_size: int = 10
|
||||
chunk_stride: int = 2
|
||||
|
||||
# ── Update cadence ──
|
||||
utd_ratio: int = 5
|
||||
policy_update_freq: int = 2
|
||||
clip_grad_norm: float = 10.0
|
||||
|
||||
# ── Learning rates ──
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
rl_token_lr: float = 1e-4
|
||||
|
||||
# ── TD learning ──
|
||||
discount: float = 0.99
|
||||
tau: float = 0.005
|
||||
num_critics: int = 2
|
||||
|
||||
# ── Policy constraint (paper Eq. 5) ──
|
||||
bc_reg_coeff: float = 0.1
|
||||
ref_dropout: float = 0.5
|
||||
|
||||
# ── Offline RL-token training ──
|
||||
vla_finetune_weight: float = 0.0
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg) -> RLTAlgorithmConfig:
|
||||
"""Build from an existing ``RLTConfig`` (cfg.policy)."""
|
||||
return cls(
|
||||
chunk_size=policy_cfg.chunk_size,
|
||||
chunk_stride=policy_cfg.chunk_stride,
|
||||
utd_ratio=policy_cfg.utd_ratio,
|
||||
policy_update_freq=policy_cfg.policy_update_freq,
|
||||
clip_grad_norm=policy_cfg.clip_grad_norm,
|
||||
actor_lr=policy_cfg.actor_lr,
|
||||
critic_lr=policy_cfg.critic_lr,
|
||||
rl_token_lr=policy_cfg.rl_token_lr,
|
||||
discount=policy_cfg.discount,
|
||||
tau=policy_cfg.tau,
|
||||
num_critics=policy_cfg.num_critics,
|
||||
bc_reg_coeff=policy_cfg.bc_reg_coeff,
|
||||
ref_dropout=policy_cfg.ref_dropout,
|
||||
vla_finetune_weight=policy_cfg.vla_finetune_weight,
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> RLTAlgorithm:
|
||||
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
|
||||
|
||||
return RLTAlgorithm(policy=policy, config=self)
|
||||
@@ -1,319 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT (RL Token) algorithm.
|
||||
|
||||
Implements the two-stage training from "RL Token: Bootstrapping Online RL
|
||||
with Vision-Language-Action Models" (Xu et al., Physical Intelligence, 2026).
|
||||
|
||||
Stage 1 (offline): Train RL-token encoder/decoder via reconstruction loss.
|
||||
Stage 2 (online): Train actor-critic with chunked TD, BC regularization,
|
||||
reference-action pass-through, and reference-action dropout.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.rlt.modeling_rlt import MLP, RLTPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.rl.algorithms.base import (
|
||||
BatchType,
|
||||
RLAlgorithm,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
class RLTCritic(nn.Module):
|
||||
"""Q-function over (state, action_chunk) pairs.
|
||||
|
||||
Paper Eq. 3: Q_psi(x, a_{1:C})
|
||||
|
||||
Training-only component — lives on the algorithm side, not in the policy.
|
||||
"""
|
||||
|
||||
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int]):
|
||||
super().__init__()
|
||||
self.net = MLP(state_dim + action_chunk_dim, hidden_dims, output_dim=1)
|
||||
|
||||
def forward(self, state: Tensor, action_chunk: Tensor) -> Tensor:
|
||||
x = torch.cat([state, action_chunk], dim=-1)
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class RLTAlgorithm(RLAlgorithm):
|
||||
"""RL Token: lightweight actor-critic on frozen VLA features.
|
||||
|
||||
Owns the ``RLTPolicy`` (RL-token encoder/decoder + actor), a critic
|
||||
ensemble, and target networks. All VLA-specific logic (embedding
|
||||
extraction, reference actions) lives in ``_prepare_forward_batch``.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: RLTPolicy, config: RLTAlgorithmConfig):
|
||||
self.policy = policy
|
||||
self.config = config
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
self._device = get_device_from_parameters(self.policy)
|
||||
self._is_online = False
|
||||
|
||||
self._init_critics()
|
||||
self._move_to_device()
|
||||
|
||||
# ── Initialization ───────────────────────────────────────────────
|
||||
|
||||
def _init_critics(self) -> None:
|
||||
state_dim = self.policy._state_dim
|
||||
action_chunk_dim = self.policy._action_chunk_dim
|
||||
hidden_dims = self.policy.config.critic.hidden_dims
|
||||
|
||||
self.critics = torch.nn.ModuleList(
|
||||
[RLTCritic(state_dim, action_chunk_dim, hidden_dims) for _ in range(self.config.num_critics)]
|
||||
)
|
||||
self.critic_targets = torch.nn.ModuleList([copy.deepcopy(c) for c in self.critics])
|
||||
for ct in self.critic_targets:
|
||||
ct.requires_grad_(False)
|
||||
|
||||
def _move_to_device(self) -> None:
|
||||
self.critics.to(self._device)
|
||||
self.critic_targets.to(self._device)
|
||||
|
||||
# ── Offline phase (Stage 1): RL-token training ───────────────────
|
||||
|
||||
def supports_offline_phase(self) -> bool:
|
||||
return True
|
||||
|
||||
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""Train RL-token encoder/decoder on demonstration data.
|
||||
|
||||
Paper Eq. 2: L_ro = E[ sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2 ]
|
||||
"""
|
||||
batch = next(batch_iterator)
|
||||
|
||||
vla_embeddings = batch["state"]["observation.vla_embeddings"].to(self._device)
|
||||
z_vla = vla_embeddings.detach() # stop-gradient on VLA embeddings
|
||||
|
||||
z_rl = self.policy.rl_token_encoder(z_vla)
|
||||
z_reconstructed = self.policy.rl_token_decoder(z_rl, z_vla)
|
||||
|
||||
loss_ro = F.mse_loss(z_reconstructed, z_vla)
|
||||
|
||||
self.optimizers["rl_token"].zero_grad()
|
||||
loss_ro.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
list(self.policy.rl_token_encoder.parameters()) + list(self.policy.rl_token_decoder.parameters()),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
)
|
||||
self.optimizers["rl_token"].step()
|
||||
|
||||
self._optimization_step += 1
|
||||
return TrainingStats(losses={"loss_rl_token": loss_ro.item()})
|
||||
|
||||
def transition_to_online(self) -> None:
|
||||
"""Freeze RL-token modules; rebuild optimizers for actor-critic only."""
|
||||
self.policy.rl_token_encoder.requires_grad_(False)
|
||||
self.policy.rl_token_decoder.requires_grad_(False)
|
||||
self._is_online = True
|
||||
|
||||
self.optimizers = {
|
||||
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
|
||||
}
|
||||
self._optimization_step = 0
|
||||
|
||||
# ── Online phase (Stage 2): Actor-Critic ─────────────────────────
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One full RLT update step with UTD critic warm-up.
|
||||
|
||||
Pulls ``utd_ratio`` batches. First ``utd_ratio - 1`` are critic-only;
|
||||
the last batch also updates the actor (every ``policy_update_freq`` steps).
|
||||
"""
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch)
|
||||
self._critic_step(fb)
|
||||
self._update_target_networks()
|
||||
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch)
|
||||
critic_loss = self._critic_step(fb)
|
||||
|
||||
stats = TrainingStats(losses={"loss_critic": critic_loss})
|
||||
|
||||
if self._optimization_step % self.config.policy_update_freq == 0:
|
||||
actor_loss, bc_loss, q_val = self._actor_step(fb)
|
||||
stats.losses["loss_actor"] = actor_loss
|
||||
stats.extra["bc_loss"] = bc_loss
|
||||
stats.extra["q_value_mean"] = q_val
|
||||
|
||||
self._update_target_networks()
|
||||
self._optimization_step += 1
|
||||
return stats
|
||||
|
||||
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
|
||||
"""Convert a replay batch into algorithm-ready tensors.
|
||||
|
||||
Extracts RL-token from VLA embeddings, builds RL state, reads
|
||||
reference action from complementary_info.
|
||||
"""
|
||||
obs = batch["state"]
|
||||
next_obs = batch["next_state"]
|
||||
device = self._device
|
||||
|
||||
vla_emb = obs["observation.vla_embeddings"].to(device)
|
||||
next_vla_emb = next_obs["observation.vla_embeddings"].to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
z_rl = self.policy.rl_token_encoder(vla_emb)
|
||||
z_rl_next = self.policy.rl_token_encoder(next_vla_emb)
|
||||
|
||||
parts = [z_rl]
|
||||
next_parts = [z_rl_next]
|
||||
if "observation.state" in obs and self.policy._proprioception_dim > 0:
|
||||
prop = obs["observation.state"].to(device)
|
||||
next_prop = next_obs["observation.state"].to(device)
|
||||
parts.append(prop)
|
||||
next_parts.append(next_prop)
|
||||
|
||||
state = torch.cat(parts, dim=-1)
|
||||
next_state = torch.cat(next_parts, dim=-1)
|
||||
|
||||
action = batch[ACTION].to(device)
|
||||
reward = batch["reward"].to(device)
|
||||
done = batch["done"].to(device)
|
||||
|
||||
ref_action = None
|
||||
comp_info = batch.get("complementary_info")
|
||||
if comp_info is not None and "reference_action" in comp_info:
|
||||
ref_action = comp_info["reference_action"].to(device)
|
||||
|
||||
return {
|
||||
"state": state,
|
||||
"next_state": next_state,
|
||||
"action": action,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"reference_action": ref_action,
|
||||
}
|
||||
|
||||
def _critic_step(self, fb: dict[str, Any]) -> float:
|
||||
"""Paper Eq. 3: chunked TD with clipped double-Q target."""
|
||||
state = fb["state"]
|
||||
next_state = fb["next_state"]
|
||||
action = fb["action"]
|
||||
reward = fb["reward"]
|
||||
done = fb["done"]
|
||||
|
||||
with torch.no_grad():
|
||||
ref = fb.get("reference_action")
|
||||
if ref is None:
|
||||
ref = torch.zeros_like(action)
|
||||
next_action = self.policy.actor(next_state, ref)
|
||||
|
||||
target_qs = [ct(next_state, next_action) for ct in self.critic_targets]
|
||||
min_target_q = torch.min(torch.cat(target_qs, dim=-1), dim=-1, keepdim=True).values
|
||||
|
||||
discount_chunk = self.config.discount**self.config.chunk_size
|
||||
td_target = reward.unsqueeze(-1) + (1 - done.unsqueeze(-1)) * discount_chunk * min_target_q
|
||||
|
||||
q_preds = [c(state, action) for c in self.critics]
|
||||
loss = sum(F.mse_loss(q, td_target) for q in q_preds)
|
||||
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.critics.parameters(), max_norm=self.config.clip_grad_norm)
|
||||
self.optimizers["critic"].step()
|
||||
return loss.item()
|
||||
|
||||
def _actor_step(self, fb: dict[str, Any]) -> tuple[float, float, float]:
|
||||
"""Paper Eq. 5: maximize Q while staying near VLA reference.
|
||||
|
||||
L_pi(theta) = E[ -Q(x, a) + beta * ||a - a_tilde||^2 ]
|
||||
With reference-action dropout applied to the actor's ref input.
|
||||
"""
|
||||
state = fb["state"]
|
||||
ref = fb.get("reference_action")
|
||||
if ref is None:
|
||||
ref = torch.zeros(state.shape[0], self.policy._action_chunk_dim, device=self._device)
|
||||
|
||||
# Reference-action dropout (paper Section IV-B)
|
||||
mask = (torch.rand(ref.shape[0], 1, device=self._device) > self.config.ref_dropout).float()
|
||||
ref_input = ref * mask
|
||||
|
||||
action = self.policy.actor(state, ref_input)
|
||||
|
||||
q_value = self.critics[0](state, action)
|
||||
|
||||
bc_loss = F.mse_loss(action, ref)
|
||||
|
||||
loss = -q_value.mean() + self.config.bc_reg_coeff * bc_loss
|
||||
|
||||
self.optimizers["actor"].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), max_norm=self.config.clip_grad_norm)
|
||||
self.optimizers["actor"].step()
|
||||
|
||||
return loss.item(), bc_loss.item(), q_value.mean().item()
|
||||
|
||||
def _update_target_networks(self) -> None:
|
||||
tau = self.config.tau
|
||||
for critic, target in zip(self.critics, self.critic_targets, strict=True):
|
||||
for p, tp in zip(critic.parameters(), target.parameters(), strict=True):
|
||||
tp.data.copy_(tau * p.data + (1 - tau) * tp.data)
|
||||
|
||||
# ── Optimizer management ─────────────────────────────────────────
|
||||
|
||||
def make_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Create optimizers. Initially for RL-token (Stage 1)."""
|
||||
self.optimizers = {
|
||||
"rl_token": torch.optim.Adam(
|
||||
list(self.policy.rl_token_encoder.parameters())
|
||||
+ list(self.policy.rl_token_decoder.parameters()),
|
||||
lr=self.config.rl_token_lr,
|
||||
),
|
||||
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
|
||||
}
|
||||
return self.optimizers
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
return self.optimizers
|
||||
|
||||
# ── Weight sync ──────────────────────────────────────────────────
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Push actor + RL-token encoder to actors (small footprint)."""
|
||||
weights = {
|
||||
"actor": self.policy.actor.state_dict(),
|
||||
"rl_token_encoder": self.policy.rl_token_encoder.state_dict(),
|
||||
}
|
||||
return {k: {kk: vv.cpu() for kk, vv in v.items()} for k, v in weights.items()}
|
||||
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
if "actor" in weights:
|
||||
self.policy.actor.load_state_dict({k: v.to(device) for k, v in weights["actor"].items()})
|
||||
if "rl_token_encoder" in weights:
|
||||
self.policy.rl_token_encoder.load_state_dict(
|
||||
{k: v.to(device) for k, v in weights["rl_token_encoder"].items()}
|
||||
)
|
||||
@@ -1,18 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]
|
||||
@@ -1,81 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SAC algorithm configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig
|
||||
from lerobot.rl.algorithms.base import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
|
||||
@RLAlgorithmConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""SAC-specific hyper-parameters that control the update loop."""
|
||||
|
||||
utd_ratio: int = 1
|
||||
policy_update_freq: int = 1
|
||||
clip_grad_norm: float = 40.0
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
critic_target_update_weight: float = 0.005
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
num_discrete_actions: int | None = None
|
||||
shared_encoder: bool = True
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
use_torch_compile: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
|
||||
"""Build from an existing ``SACConfig`` (cfg.policy) for backwards compat."""
|
||||
return cls(
|
||||
utd_ratio=policy_cfg.utd_ratio,
|
||||
policy_update_freq=policy_cfg.policy_update_freq,
|
||||
clip_grad_norm=policy_cfg.grad_clip_norm,
|
||||
actor_lr=policy_cfg.actor_lr,
|
||||
critic_lr=policy_cfg.critic_lr,
|
||||
temperature_lr=policy_cfg.temperature_lr,
|
||||
discount=policy_cfg.discount,
|
||||
temperature_init=policy_cfg.temperature_init,
|
||||
target_entropy=policy_cfg.target_entropy,
|
||||
use_backup_entropy=policy_cfg.use_backup_entropy,
|
||||
critic_target_update_weight=policy_cfg.critic_target_update_weight,
|
||||
num_critics=policy_cfg.num_critics,
|
||||
num_subsample_critics=policy_cfg.num_subsample_critics,
|
||||
num_discrete_actions=policy_cfg.num_discrete_actions,
|
||||
shared_encoder=policy_cfg.shared_encoder,
|
||||
critic_network_kwargs=policy_cfg.critic_network_kwargs,
|
||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||
use_torch_compile=policy_cfg.use_torch_compile,
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
return SACAlgorithm(policy=policy, config=self)
|
||||
@@ -1,409 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SAC (Soft Actor-Critic) algorithm.
|
||||
|
||||
This module encapsulates all SAC-specific training logic (critic, actor,
|
||||
temperature, and discrete-critic updates) behind the ``RLAlgorithm`` interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.sac.modeling_sac import (
|
||||
DISCRETE_DIMENSION_INDEX,
|
||||
CriticEnsemble,
|
||||
CriticHead,
|
||||
DiscreteCritic,
|
||||
SACObservationEncoder,
|
||||
SACPolicy,
|
||||
)
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.rl.algorithms.base import (
|
||||
BatchType,
|
||||
RLAlgorithm,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.transition import move_state_dict_to_device
|
||||
|
||||
|
||||
class SACAlgorithm(RLAlgorithm):
|
||||
"""Soft Actor-Critic with optional discrete-critic head.
|
||||
|
||||
Owns the ``SACPolicy`` and its optimizers. All loss methods call
|
||||
``self.policy(batch_dict)`` rather than reaching into ``self.policy.actor``
|
||||
directly, so any policy that returns ``{"action", "log_prob"}`` from its
|
||||
``forward()`` is compatible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: SACPolicy,
|
||||
config: SACAlgorithmConfig,
|
||||
):
|
||||
self.policy = policy
|
||||
self.config = config
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
|
||||
self._device = get_device_from_parameters(self.policy)
|
||||
self._init_critic_encoder()
|
||||
self._init_critics()
|
||||
self._init_temperature()
|
||||
self._move_to_device()
|
||||
|
||||
def _init_critic_encoder(self) -> None:
|
||||
"""Build or share the encoder used by critics."""
|
||||
if self.config.shared_encoder:
|
||||
self.critic_encoder = self.policy.encoder
|
||||
self.policy.actor.encoder_is_shared = True
|
||||
else:
|
||||
self.critic_encoder = SACObservationEncoder(self.policy.config)
|
||||
|
||||
def _init_critics(self) -> None:
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
input_dim = self.critic_encoder.output_dim + action_dim
|
||||
|
||||
heads = [
|
||||
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.critic_encoder, ensemble=heads)
|
||||
|
||||
target_heads = [
|
||||
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.critic_encoder, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critic_target()
|
||||
|
||||
def _init_discrete_critic_target(self) -> None:
|
||||
"""Build only the target discrete critic."""
|
||||
input_dim = self.critic_encoder.output_dim
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.critic_encoder,
|
||||
input_dim=input_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
# TODO: (kmeftah) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha) and default target entropy."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _move_to_device(self) -> None:
|
||||
"""Move algorithm-owned modules to the policy device."""
|
||||
self.critic_ensemble.to(self._device)
|
||||
self.critic_target.to(self._device)
|
||||
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
|
||||
if hasattr(self, "discrete_critic_target"):
|
||||
self.discrete_critic_target.to(self._device)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""Run one full SAC update with UTD critic warm-up.
|
||||
|
||||
Pulls ``utd_ratio`` batches from ``batch_iterator``. The first
|
||||
``utd_ratio - 1`` batches are used for critic-only warm-up steps;
|
||||
the last batch drives the full update (critic + actor + temperature).
|
||||
"""
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
forward_batch = self._prepare_forward_batch(batch)
|
||||
|
||||
loss_critic = self._compute_loss_critic(forward_batch)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.critic_ensemble.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.discrete_critic.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["discrete_critic"].step()
|
||||
self._update_target_networks()
|
||||
|
||||
batch = next(batch_iterator)
|
||||
forward_batch = self._prepare_forward_batch(batch)
|
||||
|
||||
loss_critic = self._compute_loss_critic(forward_batch)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.critic_ensemble.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
critic_loss_val = loss_critic.item()
|
||||
stats = TrainingStats(
|
||||
losses={"loss_critic": critic_loss_val},
|
||||
grad_norms={"critic": critic_grad_norm},
|
||||
)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete.backward()
|
||||
dc_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.discrete_critic.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["discrete_critic"].step()
|
||||
stats.losses["loss_discrete_critic"] = loss_discrete.item()
|
||||
stats.grad_norms["discrete_critic"] = dc_grad
|
||||
|
||||
if self._optimization_step % self.config.policy_update_freq == 0:
|
||||
for _ in range(self.config.policy_update_freq):
|
||||
actor_loss = self._compute_loss_actor(forward_batch)
|
||||
self.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
actor_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.actor.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["actor"].step()
|
||||
|
||||
temp_loss = self._compute_loss_temperature(forward_batch)
|
||||
self.optimizers["temperature"].zero_grad()
|
||||
temp_loss.backward()
|
||||
temp_grad = torch.nn.utils.clip_grad_norm_(
|
||||
[self.log_alpha],
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["temperature"].step()
|
||||
|
||||
stats.losses["loss_actor"] = actor_loss.item()
|
||||
stats.losses["loss_temperature"] = temp_loss.item()
|
||||
stats.grad_norms["actor"] = actor_grad
|
||||
stats.grad_norms["temperature"] = temp_grad
|
||||
stats.extra["temperature"] = self.temperature
|
||||
|
||||
self._update_target_networks()
|
||||
|
||||
self._optimization_step += 1
|
||||
return stats
|
||||
|
||||
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
next_obs_features = batch.get("next_observation_feature")
|
||||
|
||||
with torch.no_grad():
|
||||
next_output = self.policy({"state": next_observations, "observation_feature": next_obs_features})
|
||||
next_actions = next_output["action"]
|
||||
next_log_probs = next_output["log_prob"]
|
||||
|
||||
q_targets = self.critic_target(next_observations, next_actions, next_obs_features)
|
||||
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
min_q, _ = q_targets.min(dim=0)
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
actions = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
|
||||
q_preds = self.critic_ensemble(observations, actions, obs_features)
|
||||
|
||||
td_target_dup = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
critics_loss = (F.mse_loss(input=q_preds, target=td_target_dup, reduction="none").mean(dim=1)).sum()
|
||||
return critics_loss
|
||||
|
||||
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
next_obs_features = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete).long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
next_discrete_qs = self.policy.discrete_critic(next_observations, next_obs_features)
|
||||
best_next_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
target_next_qs = self.discrete_critic_target(next_observations, next_obs_features)
|
||||
target_next_q = torch.gather(target_next_qs, dim=1, index=best_next_action).squeeze(-1)
|
||||
|
||||
rewards_disc = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_disc = rewards + discrete_penalties
|
||||
target_q = rewards_disc + (1 - done) * self.config.discount * target_next_q
|
||||
|
||||
predicted_qs = self.policy.discrete_critic(observations, obs_features)
|
||||
predicted_q = torch.gather(predicted_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
return F.mse_loss(input=predicted_q, target=target_q)
|
||||
|
||||
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
|
||||
output = self.policy({"state": observations, "observation_feature": obs_features})
|
||||
actions_pi = output["action"]
|
||||
log_probs = output["log_prob"]
|
||||
|
||||
q_preds = self.critic_ensemble(observations, actions_pi, obs_features)
|
||||
min_q = q_preds.min(dim=0)[0]
|
||||
|
||||
return ((self.temperature * log_probs) - min_q).mean()
|
||||
|
||||
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
|
||||
with torch.no_grad():
|
||||
output = self.policy({"state": observations, "observation_feature": obs_features})
|
||||
log_probs = output["log_prob"]
|
||||
|
||||
return (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
|
||||
def _update_target_networks(self) -> None:
|
||||
tau = self.config.critic_target_update_weight
|
||||
for target_p, p in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
|
||||
):
|
||||
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_p, p in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.policy.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
|
||||
|
||||
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
|
||||
"""Build the dict expected by loss computation from a sampled batch."""
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
|
||||
observation_features, next_observation_features = self.get_observation_features(
|
||||
observations, next_observations
|
||||
)
|
||||
forward_batch: dict[str, Any] = {
|
||||
ACTION: batch[ACTION],
|
||||
"reward": batch["reward"],
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": batch["done"],
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
if "complementary_info" in batch:
|
||||
forward_batch["complementary_info"] = batch["complementary_info"]
|
||||
return forward_batch
|
||||
|
||||
def make_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Create Adam optimizers for the SAC components and store them."""
|
||||
actor_params = [
|
||||
p
|
||||
for n, p in self.policy.actor.named_parameters()
|
||||
if not self.config.shared_encoder or not n.startswith("encoder")
|
||||
]
|
||||
self.optimizers = {
|
||||
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
|
||||
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
|
||||
)
|
||||
return self.optimizers
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
return self.optimizers
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Policy state-dict to push to actors (includes actor + discrete critic)."""
|
||||
return move_state_dict_to_device(self.policy.state_dict(), device="cpu")
|
||||
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner."""
|
||||
state = move_state_dict_to_device(weights, device=device)
|
||||
self.policy.load_state_dict(state)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_observation_features(
|
||||
self, observations: Tensor, next_observations: Tensor
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
if not self.config.shared_encoder:
|
||||
return None, None
|
||||
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
if not self.policy.encoder.has_images:
|
||||
return None, None
|
||||
observation_features = self.policy.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = self.policy.encoder.get_cached_image_features(next_observations)
|
||||
return observation_features, next_observation_features
|
||||
@@ -1,17 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer
|
||||
|
||||
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
|
||||
@@ -1,94 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
|
||||
BatchType = dict[str, Any]
|
||||
|
||||
|
||||
class DataMixer(abc.ABC):
|
||||
"""Abstract interface for all data mixing strategies.
|
||||
|
||||
Subclasses must implement ``sample(batch_size)`` and may override
|
||||
``get_iterator`` for specialised iteration.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
"""Draw one batch of ``batch_size`` transitions."""
|
||||
...
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Infinite iterator that yields batches.
|
||||
|
||||
The default implementation repeatedly calls ``self.sample()``.
|
||||
Subclasses with underlying buffer iterators (async prefetch)
|
||||
should override this for better throughput.
|
||||
"""
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
|
||||
|
||||
class OnlineOfflineMixer(DataMixer):
|
||||
"""Mixes transitions from an online and an optional offline replay buffer.
|
||||
|
||||
When both buffers are present, each batch is constructed by sampling
|
||||
``ceil(batch_size * online_ratio)`` from the online buffer and the
|
||||
remainder from the offline buffer, then concatenating.
|
||||
|
||||
This mixer assumes both online and offline buffers are present.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer | None = None,
|
||||
online_ratio: float = 1.0,
|
||||
):
|
||||
if not 0.0 <= online_ratio <= 1.0:
|
||||
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
|
||||
self.online_buffer = online_buffer
|
||||
self.offline_buffer = offline_buffer
|
||||
self.online_ratio = online_ratio
|
||||
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
if self.offline_buffer is None:
|
||||
return self.online_buffer.sample(batch_size)
|
||||
|
||||
n_online = max(1, int(batch_size * self.online_ratio))
|
||||
n_offline = batch_size - n_online
|
||||
|
||||
online_batch = self.online_buffer.sample(n_online)
|
||||
offline_batch = self.offline_buffer.sample(n_offline)
|
||||
return concatenate_batch_transitions(online_batch, offline_batch)
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Yield batches from online/offline mixed sampling."""
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
@@ -36,7 +36,6 @@ from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
@@ -380,7 +379,6 @@ def make_processors(
|
||||
]
|
||||
|
||||
env_pipeline_steps = [
|
||||
GymHILAdapterProcessorStep(),
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -610,14 +608,7 @@ def control_loop(
|
||||
|
||||
dataset = None
|
||||
if cfg.mode == "record":
|
||||
if teleop_device:
|
||||
action_features = teleop_device.action_features
|
||||
else:
|
||||
action_features = {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": ["delta_x", "delta_y", "delta_z", "gripper"],
|
||||
}
|
||||
action_features = teleop_device.action_features
|
||||
features = {
|
||||
ACTION: action_features,
|
||||
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
|
||||
@@ -665,7 +656,7 @@ def control_loop(
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
|
||||
# Use the new step function
|
||||
transition = step_env_and_process_transition(
|
||||
@@ -734,8 +725,6 @@ def control_loop(
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Finalizing dataset before pushing to hub")
|
||||
dataset.finalize()
|
||||
logging.info("Pushing dataset to hub")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
+284
-92
@@ -65,11 +65,9 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
from lerobot.rl.wandb_utils import WandBLogger
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
@@ -95,7 +93,7 @@ from lerobot.utils.train_utils import (
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.utils.transition import move_transition_to_device
|
||||
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
@@ -266,8 +264,8 @@ def add_actor_information_and_train(
|
||||
- Transfers transitions from the actor to the replay buffer.
|
||||
- Logs received interaction messages.
|
||||
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
|
||||
- Delegates training updates to an ``RLAlgorithm`` (currently ``SACAlgorithm``).
|
||||
- Periodically pushes updated weights to actors.
|
||||
- Samples batches from the replay buffer and performs multiple critic updates.
|
||||
- Periodically updates the actor, critic, and temperature optimizers.
|
||||
- Logs training statistics, including loss values and optimization frequency.
|
||||
|
||||
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
@@ -286,15 +284,17 @@ def add_actor_information_and_train(
|
||||
# of 7%
|
||||
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
|
||||
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
|
||||
clip_grad_norm_value = cfg.policy.grad_clip_norm
|
||||
online_step_before_learning = cfg.policy.online_step_before_learning
|
||||
utd_ratio = cfg.policy.utd_ratio
|
||||
fps = cfg.env.fps
|
||||
log_freq = cfg.log_freq
|
||||
save_freq = cfg.save_freq
|
||||
policy_update_freq = cfg.policy.policy_update_freq
|
||||
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
|
||||
saving_checkpoint = cfg.save_checkpoint
|
||||
online_steps = cfg.policy.online_steps
|
||||
async_prefetch = cfg.async_prefetch
|
||||
queue_size = cfg.queue_size
|
||||
async_prefetch = cfg.policy.async_prefetch
|
||||
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
@@ -306,7 +306,7 @@ def add_actor_information_and_train(
|
||||
|
||||
logging.info("Initializing policy")
|
||||
|
||||
policy = make_policy(
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
@@ -315,24 +315,19 @@ def add_actor_information_and_train(
|
||||
|
||||
policy.train()
|
||||
|
||||
algorithm = make_algorithm(
|
||||
policy=policy,
|
||||
policy_cfg=cfg.policy,
|
||||
algorithm_name=cfg.algorithm,
|
||||
)
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
|
||||
# TODO: Re-enable processor pipeline once refactoring is validated against main
|
||||
preprocessor, postprocessor = None, None
|
||||
|
||||
# Push initial policy weights to actors (same path as periodic push)
|
||||
state_bytes = state_to_bytes(algorithm.get_weights())
|
||||
parameters_queue.put(state_bytes)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
log_training_info(cfg=cfg, policy=policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
|
||||
total_batch_size = cfg.batch_size
|
||||
batch_size = cfg.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset is not None:
|
||||
@@ -341,70 +336,20 @@ def add_actor_information_and_train(
|
||||
device=device,
|
||||
storage_device=storage_device,
|
||||
)
|
||||
|
||||
# DataMixer: online-only or online/offline 50-50 mix
|
||||
data_mixer = OnlineOfflineMixer(
|
||||
online_buffer=replay_buffer,
|
||||
offline_buffer=offline_replay_buffer,
|
||||
online_ratio=cfg.online_ratio,
|
||||
)
|
||||
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
|
||||
trainer = RLTrainer(
|
||||
algorithm=algorithm,
|
||||
data_mixer=data_mixer,
|
||||
batch_size=total_batch_size,
|
||||
preprocessor=preprocessor,
|
||||
action_dim=cfg.policy.output_features["action"].shape[0],
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
optimizers = algorithm.get_optimizers()
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
logging.info("Starting learner thread")
|
||||
interaction_message = None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
algorithm.optimization_step = optimization_step
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
|
||||
dataset_repo_id = None
|
||||
if cfg.dataset is not None:
|
||||
dataset_repo_id = cfg.dataset.repo_id
|
||||
|
||||
# ── Offline phase (e.g. RLT RL-token training, ConRFT Cal-QL pretraining) ──
|
||||
offline_steps = getattr(cfg.policy, "offline_steps", 0)
|
||||
if algorithm.supports_offline_phase() and offline_steps > 0 and offline_replay_buffer is not None:
|
||||
logging.info(f"[LEARNER] Starting offline phase ({offline_steps} steps)")
|
||||
offline_mixer = OnlineOfflineMixer(
|
||||
online_buffer=offline_replay_buffer,
|
||||
offline_buffer=None,
|
||||
online_ratio=1.0,
|
||||
)
|
||||
offline_iterator = algorithm.configure_data_iterator(
|
||||
data_mixer=offline_mixer,
|
||||
batch_size=total_batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
for step in range(offline_steps):
|
||||
if shutdown_event is not None and shutdown_event.is_set():
|
||||
logging.info("[LEARNER] Shutdown during offline phase. Exiting...")
|
||||
return
|
||||
|
||||
stats = algorithm.offline_update(offline_iterator)
|
||||
|
||||
if step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Offline step {step}/{offline_steps}: {stats.to_log_dict()}")
|
||||
if wandb_logger:
|
||||
log_dict = stats.to_log_dict()
|
||||
log_dict["offline_step"] = step
|
||||
wandb_logger.log_dict(d=log_dict, mode="train", custom_step_key="offline_step")
|
||||
|
||||
algorithm.transition_to_online()
|
||||
optimizers = algorithm.get_optimizers()
|
||||
logging.info("[LEARNER] Offline phase complete, transitioned to online")
|
||||
# Initialize iterators
|
||||
online_iterator = None
|
||||
offline_iterator = None
|
||||
|
||||
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
|
||||
while True:
|
||||
@@ -435,22 +380,180 @@ def add_actor_information_and_train(
|
||||
if len(replay_buffer) < online_step_before_learning:
|
||||
continue
|
||||
|
||||
time_for_one_optimization_step = time.time()
|
||||
if online_iterator is None:
|
||||
online_iterator = replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
|
||||
stats = trainer.training_step()
|
||||
if offline_replay_buffer is not None and offline_iterator is None:
|
||||
offline_iterator = offline_replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(utd_ratio - 1):
|
||||
# Sample from the iterators
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
# Main critic optimization
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
|
||||
# Sample for the last update in the UTD ratio
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Initialize training info dictionary
|
||||
training_infos = {
|
||||
"loss_critic": loss_critic.item(),
|
||||
"critic_grad_norm": critic_grad_norm,
|
||||
}
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Add discrete critic info to training info
|
||||
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
|
||||
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
|
||||
|
||||
# Actor and temperature optimization (at specified frequency)
|
||||
if optimization_step % policy_update_freq == 0:
|
||||
for _ in range(policy_update_freq):
|
||||
# Actor optimization
|
||||
actor_output = policy.forward(forward_batch, model="actor")
|
||||
loss_actor = actor_output["loss_actor"]
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["actor"].step()
|
||||
|
||||
# Add actor info to training info
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["actor_grad_norm"] = actor_grad_norm
|
||||
|
||||
# Temperature optimization
|
||||
temperature_output = policy.forward(forward_batch, model="temperature")
|
||||
loss_temperature = temperature_output["loss_temperature"]
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
# Add temperature info to training info
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
state_dicts = algorithm.get_weights()
|
||||
state_bytes = state_to_bytes(state_dicts)
|
||||
parameters_queue.put(state_bytes)
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
training_infos = stats.to_log_dict()
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
|
||||
# Log training metrics at specified intervals
|
||||
optimization_step = algorithm.optimization_step
|
||||
if optimization_step % log_freq == 0:
|
||||
training_infos["replay_buffer_size"] = len(replay_buffer)
|
||||
if offline_replay_buffer is not None:
|
||||
@@ -478,6 +581,7 @@ def add_actor_information_and_train(
|
||||
custom_step_key="Optimization step",
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
@@ -494,8 +598,6 @@ def add_actor_information_and_train(
|
||||
offline_replay_buffer=offline_replay_buffer,
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
fps=fps,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -580,8 +682,6 @@ def save_training_checkpoint(
|
||||
offline_replay_buffer: ReplayBuffer | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
fps: int = 30,
|
||||
preprocessor=None,
|
||||
postprocessor=None,
|
||||
) -> None:
|
||||
"""
|
||||
Save training checkpoint and associated data.
|
||||
@@ -605,8 +705,6 @@ def save_training_checkpoint(
|
||||
offline_replay_buffer: Optional offline replay buffer to save
|
||||
dataset_repo_id: Repository ID for dataset
|
||||
fps: Frames per second for dataset
|
||||
preprocessor: Optional preprocessor pipeline to save
|
||||
postprocessor: Optional postprocessor pipeline to save
|
||||
"""
|
||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||
_num_digits = max(6, len(str(online_steps)))
|
||||
@@ -623,8 +721,6 @@ def save_training_checkpoint(
|
||||
policy=policy,
|
||||
optimizer=optimizers,
|
||||
scheduler=None,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
# Save interaction step manually
|
||||
@@ -662,6 +758,58 @@ def save_training_checkpoint(
|
||||
logging.info("Resume training")
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
NOTE:
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
|
||||
A tuple containing:
|
||||
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
|
||||
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_discrete_critic = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizers["discrete_critic"] = optimizer_discrete_critic
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
# Training setup functions
|
||||
|
||||
|
||||
@@ -866,6 +1014,33 @@ def initialize_offline_replay_buffer(
|
||||
# Utilities/Helpers functions
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Get observation features from the policy encoder. It act as cache for the observation features.
|
||||
when the encoder is frozen, the observation features are not updated.
|
||||
We can save compute by caching the observation features.
|
||||
|
||||
Args:
|
||||
policy: The policy model
|
||||
observations: The current observations
|
||||
next_observations: The next observations
|
||||
|
||||
Returns:
|
||||
tuple: observation_features, next_observation_features
|
||||
"""
|
||||
|
||||
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.learner == "threads"
|
||||
|
||||
@@ -916,6 +1091,23 @@ def check_nan_in_transition(
|
||||
return nan_detected
|
||||
|
||||
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
logging.debug("[LEARNER] Pushing actor policy to the queue")
|
||||
|
||||
# Create a dictionary to hold all the state dicts
|
||||
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
|
||||
|
||||
# Add discrete critic if it exists
|
||||
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
policy.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
logging.debug("[LEARNER] Including discrete critic in state dict push")
|
||||
|
||||
state_bytes = state_to_bytes(state_dicts)
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
def process_interaction_message(
|
||||
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
|
||||
):
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import (
|
||||
BatchType,
|
||||
RLAlgorithm,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.data_sources.data_mixer import DataMixer
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def preprocess_rl_batch(preprocessor: Any, batch: BatchType, *, action_dim: int | None = None) -> BatchType:
|
||||
"""Apply a policy preprocessor to an RL batch."""
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
actions = batch[ACTION]
|
||||
|
||||
extra_action = None
|
||||
if action_dim is not None and actions.shape[-1] > action_dim:
|
||||
extra_action = actions[..., action_dim:]
|
||||
actions = actions[..., :action_dim]
|
||||
|
||||
obs_action = {**observations, ACTION: actions}
|
||||
obs_action = preprocessor(obs_action)
|
||||
batch["state"] = {k: v for k, v in obs_action.items() if k.startswith("observation.")}
|
||||
batch[ACTION] = obs_action[ACTION]
|
||||
|
||||
if extra_action is not None:
|
||||
batch[ACTION] = torch.cat([batch[ACTION], extra_action], dim=-1)
|
||||
|
||||
next_obs = {**next_observations}
|
||||
next_obs = preprocessor(next_obs)
|
||||
batch["next_state"] = {k: v for k, v in next_obs.items() if k.startswith("observation.")}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class _PreprocessedIterator:
|
||||
"""Iterator wrapper that preprocesses each sampled RL batch."""
|
||||
|
||||
__slots__ = ("_raw", "_preprocessor", "_action_dim")
|
||||
|
||||
def __init__(
|
||||
self, raw_iterator: Iterator[BatchType], preprocessor: Any, action_dim: int | None = None
|
||||
) -> None:
|
||||
self._raw = raw_iterator
|
||||
self._preprocessor = preprocessor
|
||||
self._action_dim = action_dim
|
||||
|
||||
def __iter__(self) -> _PreprocessedIterator:
|
||||
return self
|
||||
|
||||
def __next__(self) -> BatchType:
|
||||
batch = next(self._raw)
|
||||
return preprocess_rl_batch(self._preprocessor, batch, action_dim=self._action_dim)
|
||||
|
||||
|
||||
class RLTrainer:
|
||||
"""Unified training step orchestrator.
|
||||
|
||||
Holds the algorithm, a DataMixer, and an optional preprocessor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: RLAlgorithm,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
preprocessor: Any | None = None,
|
||||
action_dim: int | None = None,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
self.algorithm = algorithm
|
||||
self.data_mixer = data_mixer
|
||||
self.batch_size = batch_size
|
||||
self._preprocessor = preprocessor
|
||||
self._action_dim = action_dim
|
||||
self.async_prefetch = async_prefetch
|
||||
self.queue_size = queue_size
|
||||
|
||||
self._iterator: Iterator[BatchType] | None = None
|
||||
|
||||
self.algorithm.make_optimizers()
|
||||
|
||||
def _build_data_iterator(self) -> Iterator[BatchType]:
|
||||
"""Create a fresh algorithm-configured iterator (optionally preprocessed)."""
|
||||
raw = self.algorithm.configure_data_iterator(
|
||||
data_mixer=self.data_mixer,
|
||||
batch_size=self.batch_size,
|
||||
async_prefetch=self.async_prefetch,
|
||||
queue_size=self.queue_size,
|
||||
)
|
||||
if self._preprocessor is not None:
|
||||
return _PreprocessedIterator(raw, self._preprocessor, self._action_dim)
|
||||
return raw
|
||||
|
||||
def reset_data_iterator(self) -> None:
|
||||
"""Discard the current iterator so it will be rebuilt lazily next step."""
|
||||
self._iterator = None
|
||||
|
||||
def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
|
||||
"""Swap the active data mixer, optionally resetting the iterator."""
|
||||
self.data_mixer = data_mixer
|
||||
if reset:
|
||||
self.reset_data_iterator()
|
||||
|
||||
def training_step(self) -> TrainingStats:
|
||||
"""Run one training step (algorithm-agnostic)."""
|
||||
if self._iterator is None:
|
||||
self._iterator = self._build_data_iterator()
|
||||
return self.algorithm.update(self._iterator)
|
||||
@@ -140,7 +140,7 @@ class HopeJrArm(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -171,7 +171,7 @@ class HopeJrHand(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class KochFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -360,7 +360,7 @@ class LeKiwi(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class OmxFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ class OpenArmFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ class Reachy2Robot(Robot):
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
|
||||
return obs_dict
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class SOFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -324,7 +324,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Cameras - read images from ZMQ cameras
|
||||
for cam_name, cam in self._cameras.items():
|
||||
obs[cam_name] = cam.read_latest()
|
||||
obs[cam_name] = cam.async_read()
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
@@ -47,14 +47,16 @@ local$ rerun lerobot_pusht_episode_0.rrd
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine through streaming:
|
||||
(You need to forward the websocket port to the distant machine, with
|
||||
`ssh -L 9087:localhost:9087 username@remote-host`)
|
||||
```
|
||||
distant$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--mode distant \
|
||||
--grpc-port 9876
|
||||
--ws-port 9087
|
||||
|
||||
local$ rerun rerun+http://IP:GRPC_PORT/proxy
|
||||
local$ rerun ws://localhost:9087
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -73,7 +75,6 @@ import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
@@ -92,11 +93,10 @@ def visualize_dataset(
|
||||
num_workers: int = 0,
|
||||
mode: str = "local",
|
||||
web_port: int = 9090,
|
||||
grpc_port: int = 9876,
|
||||
ws_port: int = 9087,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
**kwargs,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert output_dir is not None, (
|
||||
@@ -126,9 +126,7 @@ def visualize_dataset(
|
||||
gc.collect()
|
||||
|
||||
if mode == "distant":
|
||||
server_uri = rr.serve_grpc(grpc_port=grpc_port)
|
||||
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port)
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
@@ -228,7 +226,7 @@ def main():
|
||||
"Mode of viewing between 'local' or 'distant'. "
|
||||
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
||||
"'distant' creates a server on the distant machine where the data is stored. "
|
||||
"Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine."
|
||||
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -240,13 +238,8 @@ def main():
|
||||
parser.add_argument(
|
||||
"--ws-port",
|
||||
type=int,
|
||||
help="deprecated, please use --grpc-port instead.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grpc-port",
|
||||
type=int,
|
||||
default=9876,
|
||||
help="gRPC port for rerun.io when `--mode distant` is set.",
|
||||
default=9087,
|
||||
help="Web socket port for rerun.io when `--mode distant` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
@@ -272,7 +265,9 @@ def main():
|
||||
|
||||
parser.add_argument(
|
||||
"--display-compressed-images",
|
||||
action="store_true",
|
||||
type=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
help="If set, display compressed images in Rerun instead of uncompressed ones.",
|
||||
)
|
||||
|
||||
@@ -282,14 +277,6 @@ def main():
|
||||
root = kwargs.pop("root")
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
if kwargs["ws_port"] is not None:
|
||||
logging.warning(
|
||||
"--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
|
||||
)
|
||||
logging.warning("Setting grpc_port to ws_port value.")
|
||||
kwargs["grpc_port"] = kwargs.pop("ws_port")
|
||||
|
||||
init_logging()
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
|
||||
|
||||
|
||||
@@ -24,107 +24,94 @@ When new_repo_id is specified, creates a new dataset.
|
||||
Usage Examples:
|
||||
|
||||
Delete episodes 0, 2, and 5 from a dataset:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Delete episodes and save to a new dataset:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_filtered \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Split dataset by fractions:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||
|
||||
Split dataset by episode indices:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
|
||||
|
||||
Split into more than two splits:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
|
||||
|
||||
Merge multiple datasets:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
|
||||
Remove camera feature:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Pick up the cube and place it"
|
||||
|
||||
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
|
||||
|
||||
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert image dataset to video format and push to hub:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Show dataset information:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
Show dataset information without feature details:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features false
|
||||
|
||||
Using JSON config file:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
@@ -197,13 +184,6 @@ class ConvertImageToVideoConfig(OperationConfig):
|
||||
max_frames_per_batch: int | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("info")
|
||||
@dataclass
|
||||
class InfoConfig(OperationConfig):
|
||||
type: str = "info"
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
@@ -456,49 +436,6 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
logging.info("Dataset saved locally (not pushed to hub)")
|
||||
|
||||
|
||||
def _get_dataset_size(repo_path):
|
||||
import os
|
||||
|
||||
total = 0
|
||||
with os.scandir(repo_path) as it:
|
||||
for entry in it:
|
||||
if entry.is_file():
|
||||
total += entry.stat().st_size
|
||||
elif entry.is_dir():
|
||||
total += _get_dataset_size(entry.path)
|
||||
return total
|
||||
|
||||
|
||||
def handle_info(cfg: EditDatasetConfig):
|
||||
if not isinstance(cfg.operation, InfoConfig):
|
||||
raise ValueError("Operation config must be InfoConfig")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
|
||||
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
|
||||
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
|
||||
sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n")
|
||||
sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n")
|
||||
sys.stdout.write(
|
||||
f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(
|
||||
f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(f"FPS: {dataset.meta.fps}\n")
|
||||
|
||||
total_file_size = _get_dataset_size(dataset.root)
|
||||
sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n")
|
||||
if cfg.operation.show_features:
|
||||
import json
|
||||
|
||||
feature_dump_str = json.dumps(
|
||||
dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ")
|
||||
)
|
||||
sys.stdout.write("Features:\n")
|
||||
sys.stdout.write(f"{feature_dump_str}\n")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
operation_type = cfg.operation.type
|
||||
@@ -515,8 +452,6 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
available = ", ".join(OperationConfig.get_known_choices())
|
||||
raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
|
||||
|
||||
@@ -0,0 +1,366 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Mirror a bimanual robot dataset by swapping left/right arms and inverting joint values.
|
||||
|
||||
This script creates a mirrored version of a dataset where:
|
||||
1. Left and right arm observations/actions are swapped
|
||||
2. Joint values are inverted according to a mirroring mask
|
||||
3. Video frames are horizontally flipped
|
||||
|
||||
Example usage:
|
||||
```shell
|
||||
python -m lerobot.scripts.lerobot_mirror_dataset \
|
||||
--repo_id=pepijn/openarm_bimanual \
|
||||
--output_repo_id=pepijn/openarm_bimanual_mirrored
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
DEFAULT_DATA_PATH,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENARM_MIRRORING_MASK = {
|
||||
"joint_1": -1, # Pan - invert
|
||||
"joint_2": -1, # Lift - invert
|
||||
"joint_3": -1, # Roll - invert
|
||||
"joint_4": 1, # Elbow - no invert
|
||||
"joint_5": -1, # W-Roll - invert
|
||||
"joint_6": -1, # W-Pitch - invert
|
||||
"joint_7": -1, # W-Yaw - invert
|
||||
"gripper": 1, # Gripper - no invert
|
||||
}
|
||||
|
||||
|
||||
def get_mirroring_mask(robot_type: str) -> dict[str, int]:
|
||||
"""Get the mirroring mask for a given robot type."""
|
||||
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
|
||||
return OPENARM_MIRRORING_MASK
|
||||
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
|
||||
|
||||
|
||||
def swap_left_right_name(name: str) -> str:
|
||||
"""Swap 'left' and 'right' in a feature name."""
|
||||
# Use placeholder to avoid double-swap
|
||||
result = name.replace("left_", "LEFT_PLACEHOLDER_")
|
||||
result = result.replace("right_", "left_")
|
||||
result = result.replace("LEFT_PLACEHOLDER_", "right_")
|
||||
return result
|
||||
|
||||
|
||||
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
|
||||
"""Mirror feature names by swapping left/right and return the new names and index mapping."""
|
||||
mirrored_names = [swap_left_right_name(n) for n in names]
|
||||
old_to_new_idx = {}
|
||||
for old_idx, old_name in enumerate(names):
|
||||
new_name = swap_left_right_name(old_name)
|
||||
new_idx = mirrored_names.index(new_name)
|
||||
old_to_new_idx[old_idx] = new_idx
|
||||
return mirrored_names, old_to_new_idx
|
||||
|
||||
|
||||
def apply_mirroring_mask(
|
||||
value: float,
|
||||
feature_name: str,
|
||||
mirroring_mask: dict[str, int],
|
||||
) -> float:
|
||||
"""Apply mirroring mask to a joint value."""
|
||||
name_without_prefix = feature_name.split("_", 1)[1] if "_" in feature_name else feature_name
|
||||
joint_name = name_without_prefix.split(".")[0]
|
||||
if joint_name in mirroring_mask:
|
||||
return value * mirroring_mask[joint_name]
|
||||
return value
|
||||
|
||||
|
||||
def mirror_array(
|
||||
array: np.ndarray,
|
||||
names: list[str],
|
||||
mirroring_mask: dict[str, int],
|
||||
) -> np.ndarray:
|
||||
"""Mirror an array of values (action or state) by swapping left/right and applying mask."""
|
||||
mirrored_names, idx_mapping = mirror_feature_names(names)
|
||||
result = np.zeros_like(array)
|
||||
for old_idx, new_idx in idx_mapping.items():
|
||||
old_name = names[old_idx]
|
||||
new_name = mirrored_names[new_idx]
|
||||
value = array[old_idx]
|
||||
mirrored_value = apply_mirroring_mask(value, new_name, mirroring_mask)
|
||||
result[new_idx] = mirrored_value
|
||||
return result
|
||||
|
||||
|
||||
def flip_video_frames(
|
||||
input_path: Path,
|
||||
output_path: Path,
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
):
|
||||
"""Flip video frames horizontally using FFmpeg with same settings as encode_video_frames."""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cmd = [
|
||||
"ffmpeg", "-y", "-i", str(input_path),
|
||||
"-vf", "hflip",
|
||||
"-c:v", vcodec,
|
||||
"-g", "2",
|
||||
"-crf", "30",
|
||||
"-r", str(int(fps)),
|
||||
"-pix_fmt", "yuv420p",
|
||||
"-loglevel", "error",
|
||||
]
|
||||
if vcodec == "libsvtav1":
|
||||
cmd.extend(["-preset", "12"])
|
||||
cmd.append(str(output_path))
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"FFmpeg failed: {result.stderr}")
|
||||
|
||||
|
||||
def mirror_dataset(
|
||||
repo_id: str,
|
||||
output_repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
output_root: str | Path | None = None,
|
||||
mirroring_mask: dict[str, int] | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Mirror a bimanual robot dataset."""
|
||||
logger.info(f"Loading dataset: {repo_id}")
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if mirroring_mask is None:
|
||||
robot_type = dataset.meta.robot_type or "bi_openarms_follower"
|
||||
mirroring_mask = get_mirroring_mask(robot_type)
|
||||
logger.info(f"Using mirroring mask for robot type: {robot_type}")
|
||||
|
||||
output_root = Path(output_root) if output_root else HF_LEROBOT_HOME / output_repo_id
|
||||
|
||||
mirrored_features = {}
|
||||
for key, feat in dataset.meta.features.items():
|
||||
new_key = swap_left_right_name(key)
|
||||
new_feat = feat.copy()
|
||||
if "names" in new_feat and new_feat["names"]:
|
||||
new_feat["names"] = [swap_left_right_name(n) for n in new_feat["names"]]
|
||||
mirrored_features[new_key] = new_feat
|
||||
|
||||
logger.info("Creating mirrored dataset metadata...")
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=output_repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=mirrored_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_root,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
new_meta.tasks = dataset.meta.tasks.copy()
|
||||
|
||||
_mirror_data(dataset, new_meta, mirroring_mask)
|
||||
_mirror_videos(dataset, new_meta, vcodec, num_workers)
|
||||
_copy_episodes_metadata(dataset, new_meta)
|
||||
|
||||
logger.info(f"Mirrored dataset saved to: {output_root}")
|
||||
return LeRobotDataset(output_repo_id, root=output_root)
|
||||
|
||||
|
||||
def _mirror_data(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
mirroring_mask: dict[str, int],
|
||||
) -> None:
|
||||
"""Mirror parquet data files."""
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
action_names = src_dataset.meta.features.get("action", {}).get("names", [])
|
||||
state_names = src_dataset.meta.features.get("observation.state", {}).get("names", [])
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Mirroring data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
if "action" in df.columns and action_names:
|
||||
actions = np.stack(df["action"].values)
|
||||
mirrored_actions = np.array([
|
||||
mirror_array(row, action_names, mirroring_mask) for row in actions
|
||||
])
|
||||
df["action"] = list(mirrored_actions)
|
||||
|
||||
if "observation.state" in df.columns and state_names:
|
||||
states = np.stack(df["observation.state"].values)
|
||||
mirrored_states = np.array([
|
||||
mirror_array(row, state_names, mirroring_mask) for row in states
|
||||
])
|
||||
df["observation.state"] = list(mirrored_states)
|
||||
|
||||
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
def _mirror_videos(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
vcodec: str,
|
||||
num_workers: int | None = None,
|
||||
) -> None:
|
||||
"""Mirror video files by flipping horizontally and swapping left/right names."""
|
||||
if not src_dataset.meta.video_keys:
|
||||
return
|
||||
|
||||
video_tasks = []
|
||||
for old_video_key in src_dataset.meta.video_keys:
|
||||
new_video_key = swap_left_right_name(old_video_key)
|
||||
for ep_idx in range(src_dataset.meta.total_episodes):
|
||||
try:
|
||||
src_path = src_dataset.root / src_dataset.meta.get_video_file_path(ep_idx, old_video_key)
|
||||
dst_relative = src_dataset.meta.get_video_file_path(ep_idx, old_video_key)
|
||||
dst_relative_str = str(dst_relative).replace(old_video_key, new_video_key)
|
||||
dst_path = dst_meta.root / dst_relative_str
|
||||
if src_path.exists():
|
||||
video_tasks.append((src_path, dst_path))
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
def process_video(task, pbar):
|
||||
src_path, dst_path = task
|
||||
pbar.set_postfix_str(src_path.name)
|
||||
flip_video_frames(src_path, dst_path, src_dataset.meta.fps, vcodec)
|
||||
return src_path
|
||||
|
||||
if num_workers is None:
|
||||
num_workers = os.cpu_count() or 16
|
||||
num_workers = min(len(video_tasks), num_workers)
|
||||
logger.info(f"Processing {len(video_tasks)} videos with {num_workers} workers")
|
||||
with tqdm(total=len(video_tasks), desc="Mirroring videos") as pbar:
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {executor.submit(process_video, t, pbar): t for t in video_tasks}
|
||||
for future in as_completed(futures):
|
||||
task = futures[future]
|
||||
future.result()
|
||||
pbar.set_postfix_str(f"done: {task[0].name}")
|
||||
pbar.update(1)
|
||||
|
||||
|
||||
def _copy_episodes_metadata(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
) -> None:
|
||||
"""Copy episodes metadata with swapped video keys."""
|
||||
episodes_dir = src_dataset.root / "meta/episodes"
|
||||
dst_episodes_dir = dst_meta.root / "meta/episodes"
|
||||
|
||||
if episodes_dir.exists():
|
||||
dst_episodes_dir.mkdir(parents=True, exist_ok=True)
|
||||
for src_parquet in episodes_dir.glob("*/*.parquet"):
|
||||
df = pd.read_parquet(src_parquet)
|
||||
columns_to_rename = {}
|
||||
for col in df.columns:
|
||||
if col.startswith("videos/"):
|
||||
parts = col.split("/")
|
||||
if len(parts) >= 2:
|
||||
video_key = parts[1]
|
||||
new_video_key = swap_left_right_name(video_key)
|
||||
new_col = col.replace(f"videos/{video_key}/", f"videos/{new_video_key}/")
|
||||
columns_to_rename[col] = new_col
|
||||
if columns_to_rename:
|
||||
df = df.rename(columns=columns_to_rename)
|
||||
dst_parquet = dst_episodes_dir / src_parquet.relative_to(episodes_dir)
|
||||
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_parquet, index=False)
|
||||
|
||||
dst_meta.info.update({
|
||||
"total_episodes": src_dataset.meta.total_episodes,
|
||||
"total_frames": src_dataset.meta.total_frames,
|
||||
"total_tasks": src_dataset.meta.total_tasks,
|
||||
"total_videos": src_dataset.meta.total_videos,
|
||||
"total_chunks": src_dataset.meta.total_chunks,
|
||||
})
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
if src_dataset.meta.stats is not None:
|
||||
mirrored_stats = _mirror_stats(src_dataset.meta.stats)
|
||||
write_stats(mirrored_stats, dst_meta.root)
|
||||
|
||||
|
||||
def _mirror_stats(stats: dict) -> dict:
|
||||
"""Mirror stats by swapping left/right in feature names."""
|
||||
mirrored = {}
|
||||
for key, value in stats.items():
|
||||
new_key = swap_left_right_name(key)
|
||||
if isinstance(value, dict):
|
||||
mirrored[new_key] = _mirror_stats(value)
|
||||
else:
|
||||
mirrored[new_key] = value
|
||||
return mirrored
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
parser = argparse.ArgumentParser(description="Mirror a bimanual robot dataset")
|
||||
parser.add_argument("--repo_id", type=str, required=True, help="Source dataset repo_id")
|
||||
parser.add_argument("--output_repo_id", type=str, required=True, help="Output dataset repo_id")
|
||||
parser.add_argument("--root", type=str, default=None, help="Source dataset root directory")
|
||||
parser.add_argument("--output_root", type=str, default=None, help="Output dataset root directory")
|
||||
parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec (libsvtav1, h264, hevc)")
|
||||
parser.add_argument("--num_workers", type=int, default=None, help="Number of parallel workers for video processing")
|
||||
parser.add_argument("--push-to-hub", action="store_true", help="Push mirrored dataset to HuggingFace Hub")
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = mirror_dataset(
|
||||
repo_id=args.repo_id,
|
||||
output_repo_id=args.output_repo_id,
|
||||
root=args.root,
|
||||
output_root=args.output_root,
|
||||
vcodec=args.vcodec,
|
||||
num_workers=args.num_workers,
|
||||
)
|
||||
|
||||
if getattr(args, "push_to_hub", False):
|
||||
logger.info(f"Pushing dataset to HuggingFace Hub: {args.output_repo_id}")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -398,14 +398,7 @@ def record_loop(
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
|
||||
sleep_time_s: float = 1 / fps - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
precise_sleep(max(1 / fps - dt_s, 0.0))
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
|
||||
@@ -175,8 +175,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when policy.device is set to CPU.
|
||||
force_cpu = cfg.policy.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
@@ -211,16 +209,98 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
delta_action_stats = None
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Compute delta action stats BEFORE distributed sync to avoid NCCL timeout
|
||||
if getattr(cfg.policy, "use_delta_actions", False):
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.processor.delta_action_processor import DeltaActionsProcessorStep, to_delta_actions
|
||||
|
||||
chunk_size = cfg.policy.chunk_size
|
||||
hf = dataset.hf_dataset
|
||||
total_frames = len(hf)
|
||||
sample_upper_bound = total_frames - chunk_size
|
||||
if sample_upper_bound <= 0:
|
||||
raise ValueError(
|
||||
f"Cannot compute delta action stats: total_frames={total_frames}, chunk_size={chunk_size}"
|
||||
)
|
||||
|
||||
max_samples = min(100_000, sample_upper_bound)
|
||||
indices = np.random.choice(sample_upper_bound, max_samples, replace=False)
|
||||
|
||||
action_names = dataset.meta.features.get("action", {}).get("names")
|
||||
delta_mask_step = DeltaActionsProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=getattr(cfg.policy, "delta_exclude_joints", []),
|
||||
action_names=action_names,
|
||||
)
|
||||
delta_mask = delta_mask_step._build_mask(dataset.meta.features["action"]["shape"][0])
|
||||
logging.info(
|
||||
f"use_delta_actions is enabled — computing delta action stats "
|
||||
f"from {max_samples} chunk samples (chunk_size={chunk_size})"
|
||||
)
|
||||
|
||||
all_delta_actions = []
|
||||
episode_indices = np.array(hf["episode_index"])
|
||||
for idx in indices:
|
||||
idx = int(idx)
|
||||
ep_idx = episode_indices[idx]
|
||||
end_idx = min(idx + chunk_size, total_frames)
|
||||
if end_idx > idx and episode_indices[end_idx - 1] != ep_idx:
|
||||
continue
|
||||
|
||||
chunk_data = hf[idx:end_idx]
|
||||
actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float()
|
||||
state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float()
|
||||
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), delta_mask).squeeze(0)
|
||||
all_delta_actions.append(delta.numpy())
|
||||
|
||||
if not all_delta_actions:
|
||||
raise RuntimeError("Failed to compute delta action stats: no valid chunks found.")
|
||||
|
||||
all_delta = np.concatenate(all_delta_actions, axis=0)
|
||||
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||
delta_action_stats = delta_stats
|
||||
dataset.meta.stats["action"] = delta_action_stats
|
||||
|
||||
norm_type = "UNKNOWN"
|
||||
if hasattr(cfg.policy, "normalization_mapping"):
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
action_norm = cfg.policy.normalization_mapping.get("ACTION", None)
|
||||
norm_type = action_norm.value if action_norm else "UNKNOWN"
|
||||
|
||||
excluded_dims = len(delta_mask) - sum(delta_mask)
|
||||
logging.info(
|
||||
f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): "
|
||||
f"delta_dims={sum(delta_mask)}/{len(delta_mask)} (excluded={excluded_dims}), "
|
||||
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, "
|
||||
f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}"
|
||||
)
|
||||
if norm_type == "QUANTILES":
|
||||
q_range = (delta_stats['q99'] - delta_stats['q01']).mean()
|
||||
logging.info(f" Quantile range (q99-q01): {q_range:.4f}")
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Ensure all ranks use the exact same delta action stats.
|
||||
if getattr(cfg.policy, "use_delta_actions", False):
|
||||
if accelerator.num_processes > 1 and torch.distributed.is_initialized():
|
||||
stats_list = [delta_action_stats]
|
||||
torch.distributed.broadcast_object_list(stats_list, src=0)
|
||||
delta_action_stats = stats_list[0]
|
||||
if delta_action_stats is not None:
|
||||
dataset.meta.stats["action"] = delta_action_stats
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
@@ -246,10 +326,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
processor_pretrained_path = cfg.policy.pretrained_path
|
||||
if (
|
||||
getattr(cfg.policy, "use_delta_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_delta_actions=true with pretrained processors can skip delta transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
@@ -257,7 +349,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
if processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
@@ -279,7 +371,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
@@ -397,7 +489,36 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
|
||||
# Debug logging for first few steps and periodically
|
||||
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||
action = batch.get("action")
|
||||
state = batch.get("observation.state")
|
||||
if action is not None and state is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] PRE-PROCESSOR — "
|
||||
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||
f"min={action.min():.4f}, max={action.max():.4f} | "
|
||||
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}"
|
||||
)
|
||||
|
||||
batch = preprocessor(batch)
|
||||
|
||||
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||
action = batch.get("action")
|
||||
state = batch.get("observation.state")
|
||||
if action is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||
f"min={action.min():.4f}, max={action.max():.4f}"
|
||||
)
|
||||
if state is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}, std={state.std():.4f}"
|
||||
)
|
||||
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
|
||||
@@ -16,14 +16,14 @@ import platform
|
||||
import time
|
||||
|
||||
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005):
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
|
||||
"""
|
||||
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
|
||||
|
||||
Parameters:
|
||||
- seconds: duration to wait
|
||||
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
|
||||
|
||||
Note:
|
||||
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.
|
||||
|
||||
@@ -95,7 +95,6 @@ def save_checkpoint(
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
|
||||
@@ -11,8 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.optim.schedulers import (
|
||||
@@ -40,10 +38,6 @@ def test_diffuser_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -62,10 +56,6 @@ def test_vqbet_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -86,10 +76,6 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,344 @@
|
||||
"""Tests for delta action transforms — full pipeline validation.
|
||||
|
||||
Tests the complete flow matching OpenPI:
|
||||
raw actions → DeltaActions → Normalize(delta_stats) → model → Unnormalize → AbsoluteActions
|
||||
|
||||
Uses real dataset: lerobot-data-collection/dagger_final_1_21
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import TransitionKey, batch_to_transition
|
||||
from lerobot.processor.delta_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
CHUNK_SIZE = 10
|
||||
REPO_ID = "lerobot-data-collection/dagger_final_1_21"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dataset():
|
||||
return LeRobotDataset(REPO_ID, episodes=[0])
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def action_dim(dataset):
|
||||
return dataset.meta.features["action"]["shape"][0]
|
||||
|
||||
|
||||
def _build_action_chunks(dataset, chunk_size, max_chunks=50):
|
||||
"""Build action chunks from hf_dataset, like the training script does."""
|
||||
hf = dataset.hf_dataset
|
||||
total = len(hf)
|
||||
all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)])
|
||||
chunks, states = [], []
|
||||
for i in range(total - chunk_size + 1):
|
||||
if all_ep[i] != all_ep[i + chunk_size - 1]:
|
||||
continue
|
||||
chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float()
|
||||
state = hf[i]["observation.state"].float()
|
||||
chunks.append(chunk_actions)
|
||||
states.append(state)
|
||||
if len(chunks) >= max_chunks:
|
||||
break
|
||||
assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}"
|
||||
return torch.stack(chunks), torch.stack(states)
|
||||
|
||||
|
||||
def _compute_delta_chunk_stats(action_chunks, states, mask):
|
||||
all_deltas = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_deltas.append(delta.numpy())
|
||||
all_delta = np.concatenate(all_deltas, axis=0)
|
||||
return get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||
|
||||
|
||||
# --- Basic roundtrip tests ---
|
||||
|
||||
def test_roundtrip_3d(action_dim):
|
||||
actions = torch.randn(4, CHUNK_SIZE, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_roundtrip_2d(action_dim):
|
||||
actions = torch.randn(4, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_no_mutation(action_dim):
|
||||
actions = torch.randn(2, CHUNK_SIZE, action_dim)
|
||||
original = actions.clone()
|
||||
state = torch.randn(2, action_dim)
|
||||
to_delta_actions(actions, state, [True] * action_dim)
|
||||
torch.testing.assert_close(actions, original)
|
||||
|
||||
|
||||
def test_exclude_joints_supports_partial_name_matching():
|
||||
names = [
|
||||
"right_joint_1.pos",
|
||||
"right_gripper.pos",
|
||||
"left_joint_1.pos",
|
||||
"left_gripper.pos",
|
||||
]
|
||||
step = DeltaActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names)
|
||||
assert step._build_mask(len(names)) == [True, False, True, False]
|
||||
|
||||
|
||||
# --- Chunk-level delta stats test ---
|
||||
|
||||
def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim):
|
||||
"""Chunk-level delta stats should have larger std than per-frame delta stats."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
chunk_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Per-frame stats
|
||||
hf = dataset.hf_dataset
|
||||
n = min(500, len(hf))
|
||||
frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float()
|
||||
frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float()
|
||||
frame_deltas = to_delta_actions(frame_actions, frame_states, mask).numpy()
|
||||
frame_stats = get_feature_stats(frame_deltas, axis=0, keepdims=frame_deltas.ndim == 1)
|
||||
|
||||
assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), (
|
||||
f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= "
|
||||
f"frame std ({frame_stats['std'].mean():.4f})"
|
||||
)
|
||||
|
||||
|
||||
# --- Full pipeline roundtrip: delta → normalize → unnormalize → absolute ---
|
||||
|
||||
def test_full_pipeline_roundtrip(dataset, action_dim):
|
||||
"""Test the complete OpenPI pipeline: delta → normalize → unnormalize → absolute."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
|
||||
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: delta → normalize
|
||||
t1 = delta_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized_action = t2[TransitionKey.ACTION]
|
||||
assert normalized_action.abs().mean() < 10, (
|
||||
f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}"
|
||||
)
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered_actions = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
def test_normalized_delta_values_are_reasonable(dataset, action_dim):
|
||||
"""With correct chunk stats, normalized delta actions should be in a reasonable range."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
mean = torch.tensor(delta_stats["mean"]).float()
|
||||
std = torch.tensor(delta_stats["std"]).float()
|
||||
|
||||
all_normalized = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
normalized = (delta - mean) / (std + 1e-6)
|
||||
all_normalized.append(normalized)
|
||||
|
||||
all_normalized = torch.cat(all_normalized, dim=0)
|
||||
|
||||
pct_in_range = (all_normalized.abs() < 5).float().mean()
|
||||
assert pct_in_range > 0.9, (
|
||||
f"Only {pct_in_range*100:.1f}% of normalized values in [-5, 5], expected >90%"
|
||||
)
|
||||
|
||||
assert all_normalized.mean().abs() < 1.0, (
|
||||
f"Mean of normalized deltas is {all_normalized.mean():.2f}, expected near 0"
|
||||
)
|
||||
|
||||
|
||||
def test_processor_step_roundtrip(dataset, action_dim):
|
||||
"""DeltaActionsProcessorStep applies delta; to_absolute_actions recovers original."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_actions = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = DeltaActionsProcessorStep(enabled=True)
|
||||
delta_transition = step(transition)
|
||||
assert not torch.allclose(delta_transition[TransitionKey.ACTION], original_actions)
|
||||
|
||||
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(delta_transition[TransitionKey.ACTION], state, mask)
|
||||
torch.testing.assert_close(recovered, original_actions)
|
||||
|
||||
|
||||
def test_processor_step_disabled_is_noop(dataset, action_dim):
|
||||
"""enabled=False should be a no-op."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(2)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]),
|
||||
}
|
||||
original = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
result = DeltaActionsProcessorStep(enabled=False)(transition)
|
||||
torch.testing.assert_close(result[TransitionKey.ACTION], original)
|
||||
|
||||
|
||||
# --- Training batch shape validation ---
|
||||
|
||||
def test_delta_with_action_chunks(dataset, action_dim):
|
||||
"""Verify delta works correctly with (B, chunk_size, action_dim) shaped actions."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
|
||||
# Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim)
|
||||
batch_actions = action_chunks[:4] # (4, chunk_size, action_dim)
|
||||
batch_states = states[:4] # (4, state_dim)
|
||||
|
||||
mask = [True] * action_dim
|
||||
delta = to_delta_actions(batch_actions, batch_states, mask)
|
||||
|
||||
# First action in each chunk should be close to zero (action[t] - state[t] ≈ small)
|
||||
first_deltas = delta[:, 0, :] # (B, action_dim)
|
||||
assert first_deltas.abs().mean() < delta.abs().mean(), (
|
||||
f"First action in chunk should have smaller delta than average. "
|
||||
f"First: {first_deltas.abs().mean():.4f}, Average: {delta.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Later actions should have larger deltas
|
||||
last_deltas = delta[:, -1, :] # (B, action_dim)
|
||||
assert last_deltas.abs().mean() >= first_deltas.abs().mean(), (
|
||||
f"Last action in chunk should have >= delta than first. "
|
||||
f"Last: {last_deltas.abs().mean():.4f}, First: {first_deltas.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Roundtrip
|
||||
recovered = to_absolute_actions(delta, batch_states, mask)
|
||||
torch.testing.assert_close(recovered, batch_actions)
|
||||
|
||||
|
||||
def test_delta_stats_match_actual_data_distribution(dataset, action_dim):
|
||||
"""Verify computed stats match the actual delta distribution."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
# Compute stats like the training script does
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Also compute directly
|
||||
all_deltas = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_deltas.append(delta)
|
||||
all_deltas_tensor = torch.cat(all_deltas, dim=0)
|
||||
|
||||
# Compare mean
|
||||
actual_mean = all_deltas_tensor.mean(dim=0).numpy()
|
||||
np.testing.assert_allclose(delta_stats["mean"], actual_mean, atol=0.01)
|
||||
|
||||
# Compare std
|
||||
actual_std = all_deltas_tensor.std(dim=0).numpy()
|
||||
np.testing.assert_allclose(delta_stats["std"], actual_std, atol=0.1)
|
||||
|
||||
# Verify q01 < mean < q99
|
||||
assert (delta_stats["q01"] < delta_stats["mean"]).all(), "q01 should be < mean"
|
||||
assert (delta_stats["mean"] < delta_stats["q99"]).all(), "mean should be < q99"
|
||||
|
||||
|
||||
def test_quantile_normalization_roundtrip(dataset, action_dim):
|
||||
"""Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05)."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES}
|
||||
|
||||
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: delta → quantile normalize
|
||||
t1 = delta_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized = t2[TransitionKey.ACTION]
|
||||
# Most values should be in [-1, 1] with quantile normalization
|
||||
pct_in_range = (normalized.abs() < 2).float().mean()
|
||||
assert pct_in_range > 0.5, (
|
||||
f"Only {pct_in_range*100:.1f}% in [-2, 2] after quantile norm, expected >50%"
|
||||
)
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def test_state_not_modified_by_delta(dataset, action_dim):
|
||||
"""State should never be modified by the delta processor."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_state = batch[OBS_STATE].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = DeltaActionsProcessorStep(enabled=True)
|
||||
result = step(transition)
|
||||
|
||||
result_state = result[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
torch.testing.assert_close(result_state, original_state)
|
||||
+208
-189
@@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@@ -21,7 +23,6 @@ from torch import Tensor, nn
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
@@ -137,6 +138,41 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i
|
||||
}
|
||||
|
||||
|
||||
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Create optimizers for the SAC policy."""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha],
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
|
||||
if has_discrete_action:
|
||||
optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
@@ -176,6 +212,7 @@ def create_config_with_visual_input(
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
# Let make tests a little bit faster
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.latent_dim = 32
|
||||
|
||||
@@ -183,112 +220,75 @@ def create_config_with_visual_input(
|
||||
return config
|
||||
|
||||
|
||||
def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
|
||||
"""Helper to create policy + algorithm pair for tests that need critics."""
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm.make_optimizers()
|
||||
return algorithm, policy
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
# squeeze(0) removes batch dim when batch_size==1
|
||||
assert selected_action.shape[-1] == action_dim
|
||||
|
||||
|
||||
def test_sac_policy_select_action_with_discrete():
|
||||
"""select_action should return continuous + discrete actions."""
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.num_discrete_actions = 3
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=1, state_dim=10)
|
||||
# Squeeze to unbatched (single observation)
|
||||
observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()}
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
with torch.no_grad():
|
||||
output = policy.forward(batch)
|
||||
assert "action" in output
|
||||
assert "log_prob" in output
|
||||
assert "action_mean" in output
|
||||
assert output["action"].shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
temp_loss = algorithm._compute_loss_temperature(forward_batch)
|
||||
assert temp_loss.item() is not None
|
||||
assert temp_loss.shape == ()
|
||||
algorithm.optimizers["temperature"].zero_grad()
|
||||
temp_loss.backward()
|
||||
algorithm.optimizers["temperature"].step()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
policy = SACPolicy(config=config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
policy.train()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
@@ -296,181 +296,207 @@ def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape[-1] == action_dim
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
# Let's check best candidates for pretrained encoders
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||
)
|
||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||
def test_sac_training_with_pretrained_encoder(
|
||||
def test_sac_policy_with_pretrained_encoder(
|
||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||
):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.vision_encoder_name = vision_encoder_name
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_sac_training_with_shared_encoder():
|
||||
def test_sac_policy_with_shared_encoder():
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.shared_encoder = True
|
||||
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
policy.train()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_sac_training_with_discrete_critic():
|
||||
def test_sac_policy_with_discrete_critic():
|
||||
batch_size = 2
|
||||
continuous_action_dim = 9
|
||||
full_action_dim = continuous_action_dim + 1
|
||||
full_action_dim = continuous_action_dim + 1 # the last action is discrete
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(
|
||||
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
|
||||
)
|
||||
config.num_discrete_actions = 5
|
||||
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
num_discrete_actions = 5
|
||||
config.num_discrete_actions = num_discrete_actions
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
policy.train()
|
||||
|
||||
discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch)
|
||||
optimizers = make_optimizers(policy, has_discrete_action=True)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
|
||||
assert discrete_critic_loss.item() is not None
|
||||
assert discrete_critic_loss.shape == ()
|
||||
algorithm.optimizers["discrete_critic"].zero_grad()
|
||||
discrete_critic_loss.backward()
|
||||
algorithm.optimizers["discrete_critic"].step()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
# Policy.select_action now handles both continuous + discrete
|
||||
selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()})
|
||||
assert selected_action.shape[-1] == continuous_action_dim + 1
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, full_action_dim)
|
||||
|
||||
discrete_actions = selected_action[:, -1].long()
|
||||
discrete_action_values = set(discrete_actions.tolist())
|
||||
|
||||
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
|
||||
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
|
||||
)
|
||||
|
||||
|
||||
def test_sac_algorithm_target_entropy():
|
||||
def test_sac_policy_with_default_entropy():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
_, policy = _make_algorithm(config)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
assert algorithm.target_entropy == -5.0
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_algorithm_target_entropy_with_discrete_action():
|
||||
def test_sac_policy_default_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
config.num_discrete_actions = 5
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
assert algorithm.target_entropy == -3.5
|
||||
assert policy.target_entropy == -3.0
|
||||
|
||||
|
||||
def test_sac_algorithm_temperature():
|
||||
import math
|
||||
def test_sac_policy_with_predefined_entropy():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.target_entropy = -3.5
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == pytest.approx(-3.5)
|
||||
|
||||
|
||||
def test_sac_policy_update_temperature():
|
||||
"""Test that temperature property is always in sync with log_alpha."""
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
assert algorithm.temperature == pytest.approx(1.0)
|
||||
algorithm.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
assert algorithm.temperature == pytest.approx(0.1)
|
||||
assert policy.temperature == pytest.approx(1.0)
|
||||
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
# Temperature property automatically reflects log_alpha changes
|
||||
assert policy.temperature == pytest.approx(0.1)
|
||||
|
||||
|
||||
def test_sac_algorithm_update_target_network():
|
||||
def test_sac_policy_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.critic_target_update_weight = 1.0
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
for p in algorithm.critic_ensemble.parameters():
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
for p in policy.critic_ensemble.parameters():
|
||||
p.data = torch.ones_like(p.data)
|
||||
|
||||
algorithm._update_target_networks()
|
||||
for p in algorithm.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data))
|
||||
policy.update_target_networks()
|
||||
for p in policy.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data)), (
|
||||
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_critics", [1, 3])
|
||||
def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
|
||||
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_critics = num_critics
|
||||
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
assert len(algorithm.critic_ensemble.critics) == num_critics
|
||||
assert len(policy.critic_ensemble.critics) == num_critics
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load(tmp_path):
|
||||
"""Test that the policy can be saved and loaded from pretrained."""
|
||||
root = tmp_path / "test_sac_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
@@ -484,41 +510,34 @@ def test_sac_policy_save_and_load(tmp_path):
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Collect policy values before saving
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
actions = policy.select_action(observation_batch)
|
||||
|
||||
with seeded_context(12):
|
||||
# Collect policy values after loading
|
||||
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
|
||||
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
|
||||
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
# They should be the same
|
||||
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
|
||||
assert torch.allclose(actor_loss, loaded_actor_loss)
|
||||
assert torch.allclose(temperature_loss, loaded_temperature_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
|
||||
"""Discrete critic should be saved/loaded as part of the policy."""
|
||||
root = tmp_path / "test_sac_save_and_load_discrete"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_discrete_actions = 3
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
assert loaded_policy.discrete_critic is not None
|
||||
dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
@@ -23,9 +23,8 @@ import torch
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import require_package
|
||||
|
||||
@@ -297,172 +296,3 @@ def test_end_to_end_parameters_flow(cfg, data_size):
|
||||
assert received_params.keys() == input_params.keys()
|
||||
for key in input_params:
|
||||
assert torch.allclose(received_params[key], input_params[key])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression test: learner algorithm integration (no gRPC required)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_learner_algorithm_wiring():
|
||||
"""Verify that make_algorithm constructs an SACAlgorithm from config,
|
||||
make_optimizers() creates the right optimizers, update() works, and
|
||||
get_weights() output is serializable."""
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
from lerobot.transport.utils import state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
|
||||
sac_cfg = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
use_torch_compile=False,
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
optimizers = algorithm.make_optimizers()
|
||||
assert "actor" in optimizers
|
||||
assert "critic" in optimizers
|
||||
assert "temperature" in optimizers
|
||||
|
||||
batch_size = 4
|
||||
|
||||
def batch_iterator():
|
||||
while True:
|
||||
yield {
|
||||
ACTION: torch.randn(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": {OBS_STATE: torch.randn(batch_size, state_dim)},
|
||||
"next_state": {OBS_STATE: torch.randn(batch_size, state_dim)},
|
||||
"done": torch.zeros(batch_size),
|
||||
"complementary_info": {},
|
||||
}
|
||||
|
||||
stats = algorithm.update(batch_iterator())
|
||||
assert "critic" in stats.losses
|
||||
|
||||
# get_weights -> state_to_bytes round-trip
|
||||
weights = algorithm.get_weights()
|
||||
assert len(weights) > 0
|
||||
serialized = state_to_bytes(weights)
|
||||
assert isinstance(serialized, bytes)
|
||||
assert len(serialized) > 0
|
||||
|
||||
# RLTrainer with DataMixer
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
|
||||
replay_buffer = ReplayBuffer(
|
||||
capacity=50,
|
||||
device="cpu",
|
||||
state_keys=[OBS_STATE],
|
||||
storage_device="cpu",
|
||||
use_drq=False,
|
||||
)
|
||||
for _ in range(50):
|
||||
replay_buffer.add(
|
||||
state={OBS_STATE: torch.randn(state_dim)},
|
||||
action=torch.randn(action_dim),
|
||||
reward=1.0,
|
||||
next_state={OBS_STATE: torch.randn(state_dim)},
|
||||
done=False,
|
||||
truncated=False,
|
||||
)
|
||||
data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None)
|
||||
trainer = RLTrainer(
|
||||
algorithm=algorithm,
|
||||
data_mixer=data_mixer,
|
||||
batch_size=batch_size,
|
||||
async_prefetch=False,
|
||||
)
|
||||
trainer_stats = trainer.training_step()
|
||||
assert "critic" in trainer_stats.losses
|
||||
|
||||
|
||||
def test_initial_and_periodic_weight_push_consistency():
|
||||
"""Both initial and periodic weight pushes should use algorithm.get_weights()
|
||||
and produce identical structures."""
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
sac_cfg = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
use_torch_compile=False,
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm.make_optimizers()
|
||||
|
||||
# Simulate initial push (same code path the learner now uses)
|
||||
initial_weights = algorithm.get_weights()
|
||||
initial_bytes = state_to_bytes(initial_weights)
|
||||
|
||||
# Simulate periodic push
|
||||
periodic_weights = algorithm.get_weights()
|
||||
periodic_bytes = state_to_bytes(periodic_weights)
|
||||
|
||||
initial_decoded = bytes_to_state_dict(initial_bytes)
|
||||
periodic_decoded = bytes_to_state_dict(periodic_bytes)
|
||||
|
||||
assert initial_decoded.keys() == periodic_decoded.keys()
|
||||
|
||||
|
||||
def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
sac_cfg = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
use_torch_compile=False,
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
# Actor side: no optimizers
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy.eval()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
# select_action should work
|
||||
obs = {OBS_STATE: torch.randn(state_dim)}
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (action_dim,)
|
||||
|
||||
# Simulate receiving weights from learner
|
||||
fake_weights = algorithm.get_weights()
|
||||
algorithm.load_weights(fake_weights, device="cpu")
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
|
||||
def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer:
|
||||
buf = ReplayBuffer(
|
||||
capacity=capacity,
|
||||
device="cpu",
|
||||
state_keys=[OBS_STATE],
|
||||
storage_device="cpu",
|
||||
use_drq=False,
|
||||
)
|
||||
for i in range(capacity):
|
||||
buf.add(
|
||||
state={OBS_STATE: torch.randn(state_dim)},
|
||||
action=torch.randn(2),
|
||||
reward=1.0,
|
||||
next_state={OBS_STATE: torch.randn(state_dim)},
|
||||
done=bool(i % 10 == 9),
|
||||
truncated=False,
|
||||
)
|
||||
return buf
|
||||
|
||||
|
||||
def test_online_only_mixer_sample():
|
||||
"""OnlineOfflineMixer with no offline buffer returns online-only batches."""
|
||||
buf = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=0.5)
|
||||
batch = mixer.sample(batch_size=8)
|
||||
assert batch["state"][OBS_STATE].shape[0] == 8
|
||||
assert batch["action"].shape[0] == 8
|
||||
assert batch["reward"].shape[0] == 8
|
||||
|
||||
|
||||
def test_online_only_mixer_ratio_one():
|
||||
"""OnlineOfflineMixer with online_ratio=1.0 and no offline is equivalent to online-only."""
|
||||
buf = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=1.0)
|
||||
batch = mixer.sample(batch_size=10)
|
||||
assert batch["state"][OBS_STATE].shape[0] == 10
|
||||
|
||||
|
||||
def test_online_offline_mixer_sample():
|
||||
"""OnlineOfflineMixer with two buffers returns concatenated batches."""
|
||||
online = _make_buffer(capacity=50)
|
||||
offline = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(
|
||||
online_buffer=online,
|
||||
offline_buffer=offline,
|
||||
online_ratio=0.5,
|
||||
)
|
||||
batch = mixer.sample(batch_size=10)
|
||||
assert batch["state"][OBS_STATE].shape[0] == 10
|
||||
assert batch["action"].shape[0] == 10
|
||||
# 5 from online, 5 from offline (approx)
|
||||
assert batch["reward"].shape[0] == 10
|
||||
|
||||
|
||||
def test_online_offline_mixer_iterator():
|
||||
"""get_iterator yields batches of the requested size."""
|
||||
buf = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None)
|
||||
it = mixer.get_iterator(batch_size=4, async_prefetch=False)
|
||||
batch1 = next(it)
|
||||
batch2 = next(it)
|
||||
assert batch1["state"][OBS_STATE].shape[0] == 4
|
||||
assert batch2["state"][OBS_STATE].shape[0] == 4
|
||||
@@ -1,477 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.rl.algorithms.base import RLAlgorithmConfig, TrainingStats
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers (reuse patterns from tests/policies/test_sac_policy.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed():
|
||||
set_seed(42)
|
||||
|
||||
|
||||
def _make_sac_config(
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 6,
|
||||
num_discrete_actions: int | None = None,
|
||||
utd_ratio: int = 1,
|
||||
policy_update_freq: int = 1,
|
||||
with_images: bool = False,
|
||||
) -> SACConfig:
|
||||
config = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
utd_ratio=utd_ratio,
|
||||
policy_update_freq=policy_update_freq,
|
||||
num_discrete_actions=num_discrete_actions,
|
||||
use_torch_compile=False,
|
||||
)
|
||||
if with_images:
|
||||
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats[OBS_IMAGE] = {
|
||||
"mean": torch.randn(3, 1, 1).tolist(),
|
||||
"std": torch.randn(3, 1, 1).abs().tolist(),
|
||||
}
|
||||
config.latent_dim = 32
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def _make_algorithm(
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 6,
|
||||
utd_ratio: int = 1,
|
||||
policy_update_freq: int = 1,
|
||||
num_discrete_actions: int | None = None,
|
||||
with_images: bool = False,
|
||||
) -> tuple[SACAlgorithm, SACPolicy]:
|
||||
sac_cfg = _make_sac_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
utd_ratio=utd_ratio,
|
||||
policy_update_freq=policy_update_freq,
|
||||
num_discrete_actions=num_discrete_actions,
|
||||
with_images=with_images,
|
||||
)
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm.make_optimizers()
|
||||
return algorithm, policy
|
||||
|
||||
|
||||
def _make_batch(
|
||||
batch_size: int = 4,
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 6,
|
||||
with_images: bool = False,
|
||||
) -> dict:
|
||||
obs = {OBS_STATE: torch.randn(batch_size, state_dim)}
|
||||
next_obs = {OBS_STATE: torch.randn(batch_size, state_dim)}
|
||||
if with_images:
|
||||
obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84)
|
||||
next_obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84)
|
||||
return {
|
||||
ACTION: torch.randn(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": obs,
|
||||
"next_state": next_obs,
|
||||
"done": torch.zeros(batch_size),
|
||||
"complementary_info": {},
|
||||
}
|
||||
|
||||
|
||||
def _batch_iterator(**batch_kwargs):
|
||||
"""Infinite iterator that yields fresh batches (mirrors a real DataMixer iterator)."""
|
||||
while True:
|
||||
yield _make_batch(**batch_kwargs)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Registry / config tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_sac_algorithm_config_registered():
|
||||
"""SACAlgorithmConfig should be discoverable through the registry."""
|
||||
assert "sac" in RLAlgorithmConfig.get_known_choices()
|
||||
cls = RLAlgorithmConfig.get_choice_class("sac")
|
||||
assert cls is SACAlgorithmConfig
|
||||
|
||||
|
||||
def test_sac_algorithm_config_from_policy_config():
|
||||
"""from_policy_config should copy relevant fields."""
|
||||
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
|
||||
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
assert algo_cfg.utd_ratio == 4
|
||||
assert algo_cfg.policy_update_freq == 2
|
||||
assert algo_cfg.clip_grad_norm == sac_cfg.grad_clip_norm
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TrainingStats tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_training_stats_defaults():
|
||||
stats = TrainingStats()
|
||||
assert stats.losses == {}
|
||||
assert stats.grad_norms == {}
|
||||
assert stats.extra == {}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# get_weights
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_get_weights_returns_policy_state_dict():
|
||||
algorithm, policy = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
for key in policy.state_dict():
|
||||
assert key in weights
|
||||
assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu())
|
||||
|
||||
|
||||
def test_get_weights_includes_discrete_critic_when_present():
|
||||
algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
weights = algorithm.get_weights()
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
|
||||
|
||||
def test_get_weights_excludes_discrete_critic_when_absent():
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) == 0
|
||||
|
||||
|
||||
def test_get_weights_are_on_cpu():
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
for key, tensor in weights.items():
|
||||
assert tensor.device == torch.device("cpu"), f"{key} is not on CPU"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# select_action (lives on the policy, not the algorithm)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_select_action_returns_correct_shape():
|
||||
action_dim = 6
|
||||
_, policy = _make_algorithm(state_dim=10, action_dim=action_dim)
|
||||
policy.eval()
|
||||
obs = {OBS_STATE: torch.randn(10)}
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (action_dim,)
|
||||
|
||||
|
||||
def test_select_action_with_discrete_critic():
|
||||
continuous_dim = 5
|
||||
_, policy = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3)
|
||||
policy.eval()
|
||||
obs = {OBS_STATE: torch.randn(10)}
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (continuous_dim + 1,)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# update (single batch, utd_ratio=1)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_update_returns_training_stats():
|
||||
algorithm, _ = _make_algorithm()
|
||||
stats = algorithm.update(_batch_iterator())
|
||||
assert isinstance(stats, TrainingStats)
|
||||
assert "critic" in stats.losses
|
||||
assert isinstance(stats.losses["critic"], float)
|
||||
|
||||
|
||||
def test_update_populates_actor_and_temperature_losses():
|
||||
"""With policy_update_freq=1 and step 0, actor/temperature should be updated."""
|
||||
algorithm, _ = _make_algorithm(policy_update_freq=1)
|
||||
stats = algorithm.update(_batch_iterator())
|
||||
assert "actor" in stats.losses
|
||||
assert "temperature" in stats.losses
|
||||
assert "temperature" in stats.extra
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_update_freq", [2, 3])
|
||||
def test_update_skips_actor_at_non_update_steps(policy_update_freq):
|
||||
"""Actor/temperature should only update when optimization_step % freq == 0."""
|
||||
algorithm, _ = _make_algorithm(policy_update_freq=policy_update_freq)
|
||||
it = _batch_iterator()
|
||||
|
||||
# Step 0: should update actor
|
||||
stats_0 = algorithm.update(it)
|
||||
assert "actor" in stats_0.losses
|
||||
|
||||
# Step 1: should NOT update actor
|
||||
stats_1 = algorithm.update(it)
|
||||
assert "actor" not in stats_1.losses
|
||||
|
||||
|
||||
def test_update_increments_optimization_step():
|
||||
algorithm, _ = _make_algorithm()
|
||||
it = _batch_iterator()
|
||||
assert algorithm.optimization_step == 0
|
||||
algorithm.update(it)
|
||||
assert algorithm.optimization_step == 1
|
||||
algorithm.update(it)
|
||||
assert algorithm.optimization_step == 2
|
||||
|
||||
|
||||
def test_update_with_discrete_critic():
|
||||
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete
|
||||
assert "discrete_critic" in stats.losses
|
||||
assert "discrete_critic" in stats.grad_norms
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# update with UTD ratio > 1
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize("utd_ratio", [2, 4])
|
||||
def test_update_with_utd_ratio(utd_ratio):
|
||||
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
|
||||
stats = algorithm.update(_batch_iterator())
|
||||
assert isinstance(stats, TrainingStats)
|
||||
assert "critic" in stats.losses
|
||||
assert algorithm.optimization_step == 1
|
||||
|
||||
|
||||
def test_update_utd_ratio_pulls_utd_batches():
|
||||
"""next(batch_iterator) should be called exactly utd_ratio times."""
|
||||
utd_ratio = 3
|
||||
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def counting_iterator():
|
||||
nonlocal call_count
|
||||
while True:
|
||||
call_count += 1
|
||||
yield _make_batch()
|
||||
|
||||
algorithm.update(counting_iterator())
|
||||
assert call_count == utd_ratio
|
||||
|
||||
|
||||
def test_update_utd_ratio_3_critic_warmup_changes_weights():
|
||||
"""With utd_ratio=3, critic weights should change after update (3 critic steps)."""
|
||||
algorithm, policy = _make_algorithm(utd_ratio=3)
|
||||
|
||||
critic_params_before = {n: p.clone() for n, p in algorithm.critic_ensemble.named_parameters()}
|
||||
|
||||
algorithm.update(_batch_iterator())
|
||||
|
||||
changed = False
|
||||
for n, p in algorithm.critic_ensemble.named_parameters():
|
||||
if not torch.equal(p, critic_params_before[n]):
|
||||
changed = True
|
||||
break
|
||||
assert changed, "Critic weights should have changed after UTD update"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# get_observation_features
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_get_observation_features_returns_none_without_frozen_encoder():
|
||||
algorithm, _ = _make_algorithm(with_images=False)
|
||||
obs = {OBS_STATE: torch.randn(4, 10)}
|
||||
next_obs = {OBS_STATE: torch.randn(4, 10)}
|
||||
feat, next_feat = algorithm.get_observation_features(obs, next_obs)
|
||||
assert feat is None
|
||||
assert next_feat is None
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# optimization_step setter
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_optimization_step_can_be_set_for_resume():
|
||||
algorithm, _ = _make_algorithm()
|
||||
algorithm.optimization_step = 100
|
||||
assert algorithm.optimization_step == 100
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# make_algorithm factory
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_make_algorithm_returns_sac_for_sac_policy():
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
|
||||
def test_make_optimizers_creates_expected_keys():
|
||||
"""make_optimizers() should populate the algorithm with Adam optimizers."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
optimizers = algorithm.make_optimizers()
|
||||
assert "actor" in optimizers
|
||||
assert "critic" in optimizers
|
||||
assert "temperature" in optimizers
|
||||
assert all(isinstance(v, torch.optim.Adam) for v in optimizers.values())
|
||||
assert algorithm.get_optimizers() is optimizers
|
||||
|
||||
|
||||
def test_actor_side_no_optimizers():
|
||||
"""Actor-side usage: no optimizers needed, make_optimizers is not called."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
|
||||
def test_make_algorithm_copies_config_fields():
|
||||
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert algorithm.config.utd_ratio == 5
|
||||
assert algorithm.config.policy_update_freq == 3
|
||||
|
||||
|
||||
def test_make_algorithm_raises_for_unknown_type():
|
||||
class FakeConfig:
|
||||
type = "unknown_algo"
|
||||
|
||||
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
|
||||
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# load_weights (round-trip with get_weights)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_load_weights_round_trip():
|
||||
"""get_weights -> load_weights should restore identical parameters on a fresh policy."""
|
||||
algo_src, _ = _make_algorithm(state_dim=10, action_dim=6)
|
||||
algo_src.update(_batch_iterator())
|
||||
|
||||
sac_cfg = _make_sac_config(state_dim=10, action_dim=6)
|
||||
policy_dst = SACPolicy(config=sac_cfg)
|
||||
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
|
||||
|
||||
weights = algo_src.get_weights()
|
||||
algo_dst.load_weights(weights, device="cpu")
|
||||
|
||||
for key in weights:
|
||||
assert torch.equal(
|
||||
algo_dst.policy.state_dict()[key].cpu(),
|
||||
weights[key].cpu(),
|
||||
), f"Policy param '{key}' mismatch after load_weights"
|
||||
|
||||
|
||||
def test_load_weights_round_trip_with_discrete_critic():
|
||||
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
algo_src.update(_batch_iterator(action_dim=7))
|
||||
|
||||
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
|
||||
policy_dst = SACPolicy(config=sac_cfg)
|
||||
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
|
||||
|
||||
weights = algo_src.get_weights()
|
||||
algo_dst.load_weights(weights, device="cpu")
|
||||
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
for key in dc_keys:
|
||||
assert torch.equal(
|
||||
algo_dst.policy.state_dict()[key].cpu(),
|
||||
weights[key].cpu(),
|
||||
), f"Discrete critic param '{key}' mismatch after load_weights"
|
||||
|
||||
|
||||
def test_load_weights_ignores_missing_discrete_critic():
|
||||
"""load_weights should not fail when weights lack discrete_critic on a non-discrete policy."""
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
algorithm.load_weights(weights, device="cpu")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TrainingStats generic losses dict
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_training_stats_generic_losses():
|
||||
stats = TrainingStats(
|
||||
losses={"loss_bc": 0.5, "loss_q": 1.2},
|
||||
extra={"temperature": 0.1},
|
||||
)
|
||||
assert stats.losses["loss_bc"] == 0.5
|
||||
assert stats.losses["loss_q"] == 1.2
|
||||
assert stats.extra["temperature"] == 0.1
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Registry-driven build_algorithm
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_build_algorithm_via_config():
|
||||
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
|
||||
sac_cfg = _make_sac_config(utd_ratio=2)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
|
||||
algorithm = algo_config.build_algorithm(policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.config.utd_ratio == 2
|
||||
|
||||
|
||||
def test_make_algorithm_uses_build_algorithm():
|
||||
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
@@ -1,115 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm, TrainingStats
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
class _CountingAlgorithm(RLAlgorithm):
|
||||
def __init__(self):
|
||||
self.configure_calls = 0
|
||||
self.update_calls = 0
|
||||
|
||||
def select_action(self, observation: dict[str, Tensor]) -> Tensor:
|
||||
return torch.zeros(1)
|
||||
|
||||
def configure_data_iterator(
|
||||
self,
|
||||
data_mixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
self.configure_calls += 1
|
||||
return data_mixer.get_iterator(
|
||||
batch_size=batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
def make_optimizers(self):
|
||||
return {}
|
||||
|
||||
def update(self, batch_iterator):
|
||||
self.update_calls += 1
|
||||
_ = next(batch_iterator)
|
||||
return TrainingStats(losses={"dummy": 1.0})
|
||||
|
||||
def load_weights(self, weights, device="cpu") -> None:
|
||||
_ = (weights, device)
|
||||
|
||||
|
||||
class _SimpleMixer:
|
||||
def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):
|
||||
_ = (async_prefetch, queue_size)
|
||||
while True:
|
||||
yield {
|
||||
"state": {OBS_STATE: torch.randn(batch_size, 3)},
|
||||
ACTION: torch.randn(batch_size, 2),
|
||||
"reward": torch.randn(batch_size),
|
||||
"next_state": {OBS_STATE: torch.randn(batch_size, 3)},
|
||||
"done": torch.zeros(batch_size),
|
||||
"truncated": torch.zeros(batch_size),
|
||||
"complementary_info": None,
|
||||
}
|
||||
|
||||
|
||||
def test_trainer_lazy_iterator_lifecycle_and_reset():
|
||||
algo = _CountingAlgorithm()
|
||||
mixer = _SimpleMixer()
|
||||
trainer = RLTrainer(algorithm=algo, data_mixer=mixer, batch_size=4, async_prefetch=False)
|
||||
|
||||
# First call builds iterator once.
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 1
|
||||
assert algo.update_calls == 1
|
||||
|
||||
# Second call reuses existing iterator.
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 1
|
||||
assert algo.update_calls == 2
|
||||
|
||||
# Explicit reset forces lazy rebuild on next step.
|
||||
trainer.reset_data_iterator()
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 2
|
||||
assert algo.update_calls == 3
|
||||
|
||||
|
||||
def test_trainer_set_data_mixer_resets_by_default():
|
||||
algo = _CountingAlgorithm()
|
||||
mixer_a = _SimpleMixer()
|
||||
mixer_b = _SimpleMixer()
|
||||
trainer = RLTrainer(algorithm=algo, data_mixer=mixer_a, batch_size=2, async_prefetch=False)
|
||||
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 1
|
||||
|
||||
trainer.set_data_mixer(mixer_b, reset=True)
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 2
|
||||
|
||||
|
||||
def test_algorithm_optimization_step_contract_defaults():
|
||||
algo = _CountingAlgorithm()
|
||||
assert algo.optimization_step == 0
|
||||
algo.optimization_step = 11
|
||||
assert algo.optimization_step == 11
|
||||
@@ -142,7 +142,6 @@ def _make_reachy2_camera_mock(*args, **kwargs):
|
||||
cam.connect = MagicMock()
|
||||
cam.disconnect = MagicMock()
|
||||
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
return cam
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
||||
ConvertImageToVideoConfig,
|
||||
DeleteEpisodesConfig,
|
||||
EditDatasetConfig,
|
||||
InfoConfig,
|
||||
MergeConfig,
|
||||
ModifyTasksConfig,
|
||||
OperationConfig,
|
||||
@@ -47,7 +46,6 @@ class TestOperationTypeParsing:
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
|
||||
@@ -65,7 +63,6 @@ class TestOperationTypeParsing:
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
|
||||
|
||||
Reference in New Issue
Block a user