mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Compare commits
152 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 86e7302e10 | |||
| 0394fae446 | |||
| 602b8e66a6 | |||
| ab4dce6fed | |||
| 40f4386e4a | |||
| 87a91b4b08 | |||
| fadb900c36 | |||
| de0663226a | |||
| 0ca9d66cae | |||
| 2222f25da3 | |||
| acae8417aa | |||
| 2697f65cf6 | |||
| 74f42f218e | |||
| ca9d49e305 | |||
| 6705876d47 | |||
| aadbd27675 | |||
| 5221647b5e | |||
| 9c981300dd | |||
| f5b27aad1b | |||
| 75f1285507 | |||
| 33cedc2f71 | |||
| aa32e6c4ab | |||
| f906270ec4 | |||
| 733b6d84db | |||
| 8abc9037a3 | |||
| e4d4ac0bda | |||
| e79b2a439b | |||
| f9ae78ca74 | |||
| e1ced538e3 | |||
| 2a98602ad6 | |||
| a2f5b3571e | |||
| cecf2eff4f | |||
| 7e6b598a51 | |||
| 4fa41ba806 | |||
| 1de2b87a92 | |||
| 3ec7c25e7d | |||
| e3c511db67 | |||
| aed4130d39 | |||
| d26349c692 | |||
| a9bce4732b | |||
| 86d69e3c1d | |||
| 2d8ac028f9 | |||
| ec1de9c9e3 | |||
| 1ea040fe8c | |||
| c028ae3a44 | |||
| 2598dbc31a | |||
| f147a4cd48 | |||
| c3fa269b21 | |||
| 385ba8d1b7 | |||
| f4ccf911fa | |||
| 0cb8c92fe4 | |||
| bc68651815 | |||
| d1f50babaa | |||
| 3316301693 | |||
| feedababd2 | |||
| 480ee3299f | |||
| 2d1fb0f508 | |||
| b1a55b0666 | |||
| 24af996f82 | |||
| 8d7eec79c8 | |||
| ccced0c9fc | |||
| 4166eeb7da | |||
| 1f93a74d8c | |||
| b16e2f25f7 | |||
| 9cc841c674 | |||
| 63c28ea395 | |||
| 98c33a4748 | |||
| 4428248a01 | |||
| 7d6f113072 | |||
| 7ac05c838d | |||
| c85f1692d6 | |||
| 9fd329713a | |||
| 97d068e5a2 | |||
| e5bea36387 | |||
| cf1d8c3d5b | |||
| 464b65cfb0 | |||
| 90145426b4 | |||
| c76bc4cdea | |||
| 20f0381f81 | |||
| a447c652cb | |||
| 8277dbf0dc | |||
| eb0918249d | |||
| 640a7889fc | |||
| 03c6ee5f9a | |||
| dfd229ae4f | |||
| aba42c805f | |||
| 8b6b41f8dc | |||
| 1771da222b | |||
| 0514616c87 | |||
| f15872293d | |||
| a97255e3d1 | |||
| 1716d599c1 | |||
| c07ab7e1fa | |||
| 5ba9fbd9ca | |||
| 38b814f3d4 | |||
| 48a963793b | |||
| 9833b84bf8 | |||
| 27eeff7535 | |||
| 202a493c14 | |||
| eadd4c0856 | |||
| 3434a5d5df | |||
| 1ba51a6d02 | |||
| c62ca6c5d2 | |||
| 4831195310 | |||
| c514d9ffe2 | |||
| 9ae4477356 | |||
| 0e545e5177 | |||
| a0c9a7d85d | |||
| 9ce6dd9e25 | |||
| 51bd288f1a | |||
| fc6262e23d | |||
| d2b16afb12 | |||
| a754c86f64 | |||
| 76e6dc1ba1 | |||
| d10d3ef251 | |||
| feebca050a | |||
| a8e7a2967c | |||
| 2cf509795e | |||
| d3846b0beb | |||
| 08d2ed8015 | |||
| 4bcd14b8de | |||
| c34935090d | |||
| 9cfd56587e | |||
| ff8584a025 | |||
| 6bc1e5186a | |||
| 69dc8165ae | |||
| 021bca2ad9 | |||
| 4e0ee0d643 | |||
| 0a8aa85871 | |||
| 76ddd8b948 | |||
| bf08733068 | |||
| e38f56c071 | |||
| 19fe69dac0 | |||
| 14319ee608 | |||
| 9b04fd25b6 | |||
| 40e98ba690 | |||
| 894d65d58a | |||
| f58d508df2 | |||
| e22b909e7c | |||
| 09f1673cbf | |||
| 4744f99990 | |||
| 01c1735739 | |||
| 6808a42455 | |||
| fff719cb4f | |||
| e2c00f6ed8 | |||
| 0f90db23c5 | |||
| 96b192f2ae | |||
| ecdc34a699 | |||
| fa6a2fb9b7 | |||
| b011643dc9 | |||
| 30c10c1c6e | |||
| 56e2360072 |
@@ -0,0 +1,140 @@
|
||||
# Streaming Video Encoding — Encode on the fly during recording
|
||||
|
||||
## Problem
|
||||
|
||||
After each episode, `save_episode()` blocks for **~79 seconds** on a 3-camera setup (3197 frames, 107s episode):
|
||||
|
||||
| Step | Time |
|
||||
|------|------|
|
||||
| Write 9591 PNGs to disk | ~19s |
|
||||
| Read PNGs back → compute image stats | ~15s |
|
||||
| Read PNGs again → encode 3× AV1 videos → delete PNGs | ~44.5s |
|
||||
| Save parquet + metadata | ~0.6s |
|
||||
| **Total** | **~79s** |
|
||||
|
||||
The entire pipeline writes frames as temporary PNGs, reads them back twice (stats + encoding), then deletes them. This round-trip is the bottleneck.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Before: sequential post-episode pipeline
|
||||
|
||||
```
|
||||
Recording loop save_episode() — BLOCKS ~79s
|
||||
┌─────────────┐ ┌──────────────────────────────────────────────────────────┐
|
||||
│ 30fps loop │ │ │
|
||||
│ │ frames │ frame_buffer ──► write PNGs ──► read PNGs ──► stats │
|
||||
│ camera ─►───┼──► list │ (~19s) │ (~15s) │
|
||||
│ teleop │ │ ▼ │
|
||||
│ policy │ │ read PNGs ──► AV1 encode ──► delete PNGs │
|
||||
│ │ │ (~44.5s) │
|
||||
└──────┬───────┘ └──────────────────────────────────────────────────────────┘
|
||||
│ │
|
||||
▼ ▼
|
||||
episode ends next episode
|
||||
(~107s recording) (~79s blocked)
|
||||
```
|
||||
|
||||
**Data path:** `frame → list → PNG disk → read → stats` + `PNG disk → read → encode → MP4 → delete PNGs`
|
||||
|
||||
### After: streaming pipeline (encodes during recording)
|
||||
|
||||
```
|
||||
Recording loop (encoding happens HERE) save_episode() — ~0.5s
|
||||
┌───────────────────────────────────────┐ ┌──────────────────┐
|
||||
│ 30fps control loop │ │ │
|
||||
│ │ │ flush encoders │
|
||||
│ camera ──► frame ─┬─► queue ──► [T1] ├── AV1 ─┤ (already done) │
|
||||
│ │ queue ──► [T2] ├── AV1 ─┤ ~0.16s │
|
||||
│ │ queue ──► [T3] ├── AV1 ─┤ │
|
||||
│ │ │ │ running stats │
|
||||
│ └─► downsample ──► │─ stats ─┤ → finalize │
|
||||
│ RunningQuantile │ │ ~0.01s │
|
||||
│ teleop / policy (never blocked) │ │ │
|
||||
└───────────────────────────────────────┘ │ save parquet │
|
||||
│ ~0.36s │
|
||||
[T1] [T2] [T3] = encoder threads └──────────────────┘
|
||||
(one per camera, GIL released by PyAV)
|
||||
```
|
||||
|
||||
**Data path:** `frame → queue → encode → MP4` (zero PNGs, zero re-reads)
|
||||
|
||||
## Stats computation changes
|
||||
|
||||
| | Before | After |
|
||||
|---|---|---|
|
||||
| **Method** | `compute_episode_stats()` reads all PNGs from disk, decodes them, computes min/max/mean/std/quantiles | `RunningQuantileStats` accumulates stats incrementally per frame during recording |
|
||||
| **Input** | Full-resolution PNGs read back from disk | Downsampled frames (via `auto_downsample_height_width`, ~150×100px) directly from memory |
|
||||
| **When** | After episode ends, inside `save_episode()` | During recording, inside `add_frame()` (~2ms per frame) |
|
||||
| **Output** | `{mean, std, min, max, q01..q99}` shaped `(C,1,1)` in `[0,1]` | Identical shape and scale — `RunningQuantileStats.get_statistics()` → reshape `(C,1,1)` / 255 |
|
||||
| **I/O** | Reads 9591 PNGs (~15s) | Zero disk I/O |
|
||||
| **Numeric features** | Computed from episode buffer (unchanged) | Computed from episode buffer (unchanged) |
|
||||
|
||||
The running stats use the same `auto_downsample_height_width` function and produce the same statistical keys (`mean`, `std`, `min`, `max`, `count`, `q01`, `q10`, `q50`, `q90`, `q99`). Video features are excluded from the post-episode `compute_episode_stats()` call when streaming is active — only numeric features go through that path.
|
||||
|
||||
## Results
|
||||
|
||||
Tested on the same 3-camera setup (2028 frames, 67.6s episode):
|
||||
|
||||
| Step | Before | After | Speedup |
|
||||
|------|--------|-------|---------|
|
||||
| Frame writing (PNGs) | ~19s | **0s** | ∞ (eliminated) |
|
||||
| Episode stats | ~15s | **0.01s** | 1500× |
|
||||
| Video encoding | ~44.5s | **0.16s** | 278× |
|
||||
| Parquet + meta | ~0.6s | **0.36s** | ~same |
|
||||
| **Total `save_episode()`** | **~79s** | **0.55s** | **143×** |
|
||||
|
||||
The video encoding time drops to near-zero because most encoding already happened during recording. `finish_episode()` only flushes the last few buffered frames.
|
||||
|
||||
### Per-frame overhead during recording
|
||||
|
||||
| Operation | Time |
|
||||
|-----------|------|
|
||||
| `queue.put(frame)` (non-blocking) | ~0.01ms |
|
||||
| `auto_downsample_height_width` | ~0.5ms |
|
||||
| `RunningQuantileStats.update` | ~1ms |
|
||||
| **Total per frame** | **~2ms** (well within 33ms budget at 30fps) |
|
||||
|
||||
## Usage
|
||||
|
||||
Streaming is **on by default**. Users on weaker PCs can disable it to fall back to the old post-episode pipeline:
|
||||
|
||||
```bash
|
||||
# Default (streaming ON)
|
||||
lerobot-record --dataset.repo_id=user/dataset ...
|
||||
|
||||
# Old behavior (streaming OFF)
|
||||
lerobot-record --dataset.repo_id=user/dataset --dataset.streaming_encoding=false
|
||||
```
|
||||
|
||||
For the RaC data collection script, set `streaming_encoding: false` in the dataset config.
|
||||
|
||||
## Files Changed
|
||||
|
||||
### `src/lerobot/datasets/video_utils.py`
|
||||
- Added `StreamingVideoEncoder` — manages one `_CameraEncoder` thread per camera
|
||||
- Added `_CameraEncoder` — daemon thread that reads frames from a queue and encodes with PyAV
|
||||
- Non-blocking unbounded queue ensures the control loop is never delayed
|
||||
|
||||
### `src/lerobot/datasets/lerobot_dataset.py`
|
||||
- `create()` / `start_streaming_encoder()`: new `streaming_encoding` parameter
|
||||
- `add_frame()`: when streaming, feeds frames to encoder + accumulates running stats instead of writing PNGs
|
||||
- `save_episode()`: when streaming, uses running stats and calls `finish_episode()` to get already-encoded video paths
|
||||
- `clear_episode_buffer()`: cancels in-progress encoding on re-record
|
||||
- `finalize()`: cleans up encoder on shutdown
|
||||
- **Full backward compatibility**: when `streaming_encoding=False`, all existing code paths are unchanged
|
||||
|
||||
### `src/lerobot/scripts/lerobot_record.py`
|
||||
- Added `streaming_encoding: bool = True` to `DatasetRecordConfig`
|
||||
- Wired through to both `create()` and `resume` paths
|
||||
|
||||
### `examples/rac/rac_data_collection_openarms_rtc.py`
|
||||
- Added `streaming_encoding: bool = True` to `RaCRTCDatasetConfig`
|
||||
- Frames are added inline during the control loop (streaming) or buffered for post-loop writing (old path)
|
||||
- Automatically detects mode and adjusts behavior
|
||||
|
||||
## Design Notes
|
||||
|
||||
- **Why threads, not processes?** PyAV/FFmpeg releases the GIL during encoding. Threads share memory (zero-copy frame passing), avoiding the serialization overhead of multiprocessing.
|
||||
- **Why unbounded queue?** At 30fps production vs ~72fps encoding throughput, the queue stays near-empty. Even during brief encoder stalls, memory growth is bounded by episode length. The control loop must never block.
|
||||
- **Why running stats?** Avoids the expensive read-back-from-disk step. `RunningQuantileStats` + `auto_downsample_height_width` compute identical statistics incrementally with ~2ms overhead per frame.
|
||||
- **Backward compatible**: Setting `streaming_encoding=false` restores the original PNG → encode pipeline exactly. No behavior changes for existing users who don't opt in.
|
||||
+42
-42
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
|
||||
@@ -57,6 +57,8 @@
|
||||
title: Use Async Inference
|
||||
- local: rtc
|
||||
title: Real-Time Chunking (RTC)
|
||||
- local: training_time_rtc
|
||||
title: Training-Time RTC
|
||||
title: "Inference"
|
||||
- sections:
|
||||
- local: envhub
|
||||
|
||||
@@ -185,7 +185,7 @@ echo $HF_USER
|
||||
Use the standard recording command:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python src/lerobot/scripts/lerobot_record.py \
|
||||
--robot.type=earthrover_mini_plus \
|
||||
--teleop.type=keyboard_rover \
|
||||
--dataset.repo_id=your_username/dataset_name \
|
||||
|
||||
@@ -224,7 +224,7 @@ lerobot-record \
|
||||
--teleop.port=/dev/tty.usbmodem1201 \
|
||||
--teleop.id=right \
|
||||
--teleop.side=right \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.single_task="Hand recording test with video data" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
@@ -241,7 +241,7 @@ lerobot-replay \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
--robot.side=right \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_camera \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_camera \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
@@ -249,13 +249,13 @@ lerobot-replay \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
--job_name=hopejr \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=<USER>/hand_test_policy
|
||||
--policy.repo_id=nepyope/hand_test_policy
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
@@ -270,7 +270,7 @@ lerobot-record \
|
||||
--robot.side=right \
|
||||
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=<USER>/eval_hopejr \
|
||||
--dataset.repo_id=nepyope/eval_hopejr \
|
||||
--dataset.single_task="Evaluate hopejr hand policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
|
||||
@@ -0,0 +1,328 @@
|
||||
# OpenArms Robot
|
||||
|
||||
OpenArms is a 7 DOF robotic arm with a gripper, designed by [Enactic, Inc.](https://www.enactic.com/) It uses Damiao motors controlled via CAN bus communication and MIT control mode for smooth, precise motion.
|
||||
|
||||
## Hardware Overview
|
||||
|
||||
- **7 DOF per arm** (14 DOF total for dual arm setup)
|
||||
- **1 gripper per arm** (2 grippers total)
|
||||
- **Damiao motors** with 4 different types:
|
||||
- **DM8009** (DM-J8009P-2EC) for shoulders (J1, J2) - high torque
|
||||
- **DM4340** for shoulder rotation and elbow (J3, J4)
|
||||
- **DM4310** (DM-J4310-2EC V1.1) for wrist (J5, J6, J7) and gripper (J8)
|
||||
- **24V power supply** required
|
||||
- **CAN interface device**:
|
||||
- **Linux**: Any SocketCAN-compatible adapter
|
||||
- **macOS**: CANable, PEAK PCAN-USB, or Kvaser USBcan
|
||||
- Proper CAN wiring (CANH, CANL, 120Ω termination)
|
||||
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
Each arm has the following motor configuration based on the [OpenArm setup guide](https://docs.openarm.dev/software/setup/):
|
||||
|
||||
| Joint | Motor | Motor Type | Sender CAN ID | Receiver ID | Description |
|
||||
|-------|-------|------------|---------------|-------------|-------------|
|
||||
| J1 | joint_1 | DM8009 | 0x01 | 0x11 | Shoulder pan |
|
||||
| J2 | joint_2 | DM8009 | 0x02 | 0x12 | Shoulder lift |
|
||||
| J3 | joint_3 | DM4340 | 0x03 | 0x13 | Shoulder rotation |
|
||||
| J4 | joint_4 | DM4340 | 0x04 | 0x14 | Elbow flex |
|
||||
| J5 | joint_5 | DM4310 | 0x05 | 0x15 | Wrist roll |
|
||||
| J6 | joint_6 | DM4310 | 0x06 | 0x16 | Wrist pitch |
|
||||
| J7 | joint_7 | DM4310 | 0x07 | 0x17 | Wrist rotation |
|
||||
| J8 | gripper | DM4310 | 0x08 | 0x18 | Gripper |
|
||||
|
||||
For dual arm setups, the left arm uses IDs 0x09-0x10 for joints 1-8 with the same motor types.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Install system dependencies
|
||||
sudo apt install can-utils iproute2
|
||||
|
||||
# Install LeRobot with OpenArms support
|
||||
pip install -e ".[openarms]"
|
||||
```
|
||||
|
||||
## Setup Guide
|
||||
|
||||
### Step 1: Motor ID Configuration
|
||||
|
||||
**IMPORTANT**: Before using the robot, motors must be configured with the correct CAN IDs.
|
||||
|
||||
Refer to the [OpenArm Motor ID Configuration Guide](https://docs.openarm.dev/software/setup/motor-id) for detailed instructions using the Damiao Debugging Tools on Windows.
|
||||
|
||||
Key points:
|
||||
- Each motor needs a unique **Sender CAN ID** (0x01-0x08)
|
||||
- Each motor needs a unique **Receiver/Master ID** (0x11-0x18)
|
||||
- Use the Damiao Debugging Tools to set these IDs
|
||||
|
||||
### Step 2: Setup CAN Interface
|
||||
|
||||
Configure your CAN interface as described in the [OpenArm CAN Setup Guide](https://docs.openarm.dev/software/setup/can-setup):
|
||||
|
||||
#### Linux (SocketCAN)
|
||||
|
||||
```bash
|
||||
# Find your CAN interface
|
||||
ip link show
|
||||
|
||||
# Configure can0, 1, 2, 3
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000
|
||||
sudo ip link set can0 up
|
||||
|
||||
sudo ip link set can1 down
|
||||
sudo ip link set can1 type can bitrate 1000000
|
||||
sudo ip link set can1 up
|
||||
|
||||
sudo ip link set can2 down
|
||||
sudo ip link set can2 type can bitrate 1000000
|
||||
sudo ip link set can2 up
|
||||
|
||||
sudo ip link set can3 down
|
||||
sudo ip link set can3 type can bitrate 1000000
|
||||
sudo ip link set can3 up
|
||||
|
||||
# Verify configuration
|
||||
ip link show can0
|
||||
```
|
||||
|
||||
or run:
|
||||
|
||||
`examples/openarms/setup_can.sh`
|
||||
|
||||
### Testing canbus and motor connection
|
||||
|
||||
Please run this script to check if all motors can be found and to find your can-fd speed: `python examples/openarms/debug_can_communication.py`
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
|
||||
```python
|
||||
from lerobot.robots.openarms import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
# Configure for dual arm setup
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
can_interface="socketcan", # Or "auto" for auto-detection
|
||||
id="openarms_dual",
|
||||
is_dual_arm=True,
|
||||
)
|
||||
|
||||
robot = OpenArmsFollower(config)
|
||||
robot.connect()
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
On first use, you'll need to calibrate the robot:
|
||||
|
||||
```python
|
||||
robot.calibrate()
|
||||
```
|
||||
|
||||
The calibration process will:
|
||||
1. Disable torque on all motors
|
||||
2. Ask you to position arms in **hanging position with grippers closed**
|
||||
3. Set this as the zero position
|
||||
4. Ask you to move each joint through its full range
|
||||
5. Record min/max positions for each joint
|
||||
6. Save calibration to file
|
||||
|
||||
### Reading Observations
|
||||
|
||||
The robot provides comprehensive state information:
|
||||
|
||||
```python
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Observation includes for each motor:
|
||||
# - {motor_name}.pos: Position in degrees
|
||||
# - {motor_name}.vel: Velocity in degrees/second
|
||||
# - {motor_name}.torque: Motor torque
|
||||
# - {camera_name}: Camera images (if configured)
|
||||
|
||||
print(f"Right arm joint 1 position: {observation['right_joint_1.pos']:.1f}°")
|
||||
print(f"Right arm joint 1 velocity: {observation['right_joint_1.vel']:.1f}°/s")
|
||||
print(f"Right arm joint 1 torque: {observation['right_joint_1.torque']:.3f} N·m")
|
||||
```
|
||||
|
||||
### Sending Actions
|
||||
|
||||
```python
|
||||
# Send target positions (in degrees)
|
||||
action = {
|
||||
"right_joint_1.pos": 45.0,
|
||||
"right_joint_2.pos": -30.0,
|
||||
# ... all joints
|
||||
"right_gripper.pos": 45.0, # Half-closed
|
||||
}
|
||||
|
||||
actual_action = robot.send_action(action)
|
||||
```
|
||||
|
||||
### Gripper Control
|
||||
|
||||
```python
|
||||
# Open gripper
|
||||
robot.open_gripper(arm="right")
|
||||
|
||||
# Close gripper
|
||||
robot.close_gripper(arm="right")
|
||||
```
|
||||
|
||||
## Safety Features
|
||||
|
||||
### 1. Maximum Relative Target
|
||||
|
||||
Limits how far a joint can move in a single command to prevent sudden movements:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
# Limit all joints to 10 degrees per command
|
||||
max_relative_target=10.0,
|
||||
|
||||
# Or set per-motor limits
|
||||
max_relative_target={
|
||||
"right_joint_1": 15.0, # Slower moving joint
|
||||
"right_joint_2": 10.0,
|
||||
"right_gripper": 5.0, # Very slow gripper
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**How it works**: If current position is 50° and you command 80°, with `max_relative_target=10.0`, the robot will only move to 60° in that step.
|
||||
|
||||
### 2. Torque Limits
|
||||
|
||||
Control maximum torque output, especially important for grippers and teleoperation:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
# Gripper torque limit (fraction of motor's max torque)
|
||||
gripper_torque_limit=0.5, # 50% of max torque
|
||||
)
|
||||
```
|
||||
|
||||
Lower torque limits prevent damage when gripping delicate objects.
|
||||
|
||||
### 3. MIT Control Gains
|
||||
|
||||
Control responsiveness and stability via PID-like gains:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
position_kp=10.0, # Position gain (higher = more responsive)
|
||||
position_kd=0.5, # Velocity damping (higher = more damped)
|
||||
)
|
||||
```
|
||||
|
||||
**Guidelines**:
|
||||
- **For following (robot)**: Higher gains for responsiveness
|
||||
- `position_kp=10.0`, `position_kd=0.5`
|
||||
- **For teleoperation (leader)**: Lower gains or disable torque for manual movement
|
||||
- `manual_control=True` (torque disabled)
|
||||
|
||||
### 4. Velocity Limits
|
||||
|
||||
Velocity limits are enforced by the Damiao motors based on motor type. For DM4310:
|
||||
- Max velocity: 30 rad/s ≈ 1718°/s
|
||||
|
||||
The motors will automatically limit velocity to safe values.
|
||||
|
||||
## Teleoperation
|
||||
|
||||
### Leader Arm Setup
|
||||
|
||||
The leader arm is moved manually (torque disabled) to generate commands:
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.openarms import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
config = OpenArmsLeaderConfig(
|
||||
port="can1", # Separate CAN interface for leader
|
||||
id="openarms_leader",
|
||||
manual_control=True, # Torque disabled for manual movement
|
||||
is_dual_arm=True,
|
||||
)
|
||||
|
||||
leader = OpenArmsLeader(config)
|
||||
leader.connect()
|
||||
|
||||
# Read current position as action
|
||||
action = leader.get_action()
|
||||
# action contains positions for all joints in degrees
|
||||
```
|
||||
|
||||
### Safety Considerations for Teleoperation
|
||||
|
||||
1. **Use separate CAN interfaces** for leader and follower to avoid conflicts
|
||||
2. **Enable max_relative_target** on follower to smooth abrupt movements
|
||||
3. **Lower torque limits** on follower to prevent damage from tracking errors
|
||||
4. **Test with one arm** before enabling dual arm teleoperation
|
||||
5. **Have emergency stop** ready (power switch or CAN disable)
|
||||
|
||||
```python
|
||||
# Recommended follower config for teleoperation
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
max_relative_target=5.0, # Small steps for smooth following
|
||||
gripper_torque_limit=0.3, # Low torque for safety
|
||||
position_kp=5.0, # Lower gains for gentler following
|
||||
position_kd=0.3,
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Motor Shaking/Unstable
|
||||
|
||||
- **Lower control gains**: Reduce `position_kp` and `position_kd`
|
||||
- **Check calibration**: Re-run calibration procedure
|
||||
- **Verify power**: Insufficient current can cause instability
|
||||
- **Check mechanical**: Loose connections, binding, or damaged components
|
||||
|
||||
### CAN Bus Errors
|
||||
|
||||
```bash
|
||||
# Check for errors
|
||||
ip -s link show can0
|
||||
|
||||
# Reset CAN interface
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
### Control Mode
|
||||
|
||||
OpenArms uses **MIT control mode** which allows simultaneous control of:
|
||||
- Position (degrees)
|
||||
- Velocity (degrees/second)
|
||||
- Torque (N·m)
|
||||
- Position gain (Kp)
|
||||
- Velocity damping (Kd)
|
||||
|
||||
### Communication
|
||||
|
||||
- **Protocol**: CAN 2.0 at 1 Mbps (or CAN-FD at 5 Mbps)
|
||||
- **Frame format**: Standard 11-bit IDs
|
||||
- **Update rate**: Typically 50-100 Hz depending on motor count
|
||||
- **Latency**: ~10-20ms per motor command
|
||||
|
||||
## References
|
||||
|
||||
- [OpenArm Official Documentation](https://docs.openarm.dev/)
|
||||
- [OpenArm Setup Guide](https://docs.openarm.dev/software/setup/)
|
||||
- [Motor ID Configuration](https://docs.openarm.dev/software/setup/motor-id)
|
||||
- [CAN Interface Setup](https://docs.openarm.dev/software/setup/can-setup)
|
||||
- [Motor Communication Test](https://docs.openarm.dev/software/setup/configure-test)
|
||||
- [Damiao Motor Documentation](https://wiki.seeedstudio.com/damiao_series/)
|
||||
- [Enactic GitHub](https://github.com/enactic/openarm_can)
|
||||
+1
-1
@@ -60,7 +60,7 @@ policy.type=pi0
|
||||
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=./outputs/pi0_training \
|
||||
|
||||
@@ -56,7 +56,7 @@ policy.type=pi05
|
||||
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py\
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
# RaC: Recovery and Correction Training
|
||||
|
||||
RaC (Recovery and Correction) is a human-in-the-loop data collection and training paradigm that improves robot policy performance on long-horizon tasks by explicitly teaching recovery and correction behaviors.
|
||||
|
||||
**Key References:**
|
||||
- [RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction](https://arxiv.org/abs/2509.07953) (Hu et al., 2025)
|
||||
- [HG-DAgger: Interactive Imitation Learning with Human Experts](https://arxiv.org/abs/1810.02890) (Kelly et al., 2019)
|
||||
- [π∗0.6: a VLA That Learns From Experience](https://pi.website/blog/pistar06) (Physical Intelligence, 2025)
|
||||
- [SARM: Stage-Aware Reward Modeling](https://arxiv.org/abs/2509.25358) (Chen et al., 2025)
|
||||
|
||||
---
|
||||
|
||||
## Why RaC? The Problem with Standard Data Collection
|
||||
|
||||
### Standard Behavioral Cloning Data Collection Limitations
|
||||
|
||||
Standard behavior cloning trains policies on successful demonstrations. This approach can be sensitive to distribution shift and compounding errors. Because during deployment small errors can cascade and push the robot into states never seen during training.
|
||||
This is where RaC and methods like Dagger and HG-DAgger come in.
|
||||
|
||||
### Prior Human-in-the-Loop Methods
|
||||
|
||||
**DAgger** (Dataset Aggregation) addresses distribution shift by:
|
||||
- Running the novice policy to collect states
|
||||
- Querying expert for correct actions at those states
|
||||
- Aggregating new labels into training set
|
||||
|
||||
**HG-DAgger** (Human-Gated DAgger) improves on DAgger by:
|
||||
- Giving human full control authority during interventions
|
||||
- Human takes over when unsafe, provides correction, returns control
|
||||
- Better action labels because human has uninterrupted control
|
||||
|
||||
### RaC
|
||||
|
||||
RaC explicitly collects **recovery + correction** data:
|
||||
|
||||
```
|
||||
BC/DAgger: policy → mistake → human corrects → continue
|
||||
RaC: policy → mistake → human RECOVERS (teleop back) → CORRECTS → END
|
||||
```
|
||||
|
||||
The critical insight is **Rule 1 (Recover then Correct)**:
|
||||
- Every intervention starts with human teleoperating back to an in-distribution state
|
||||
- Then human provides correction to complete the current subtask
|
||||
- Both segments are recorded as training data
|
||||
- This teaches the policy: "when things go wrong, go back and retry"
|
||||
|
||||
**Rule 2 (Terminate after Intervention)**:
|
||||
- Episode ends after correction completes
|
||||
- Avoids mixed policy/human data on later subtasks
|
||||
- Keeps data distribution clean
|
||||
|
||||
---
|
||||
|
||||
## Comparison Table
|
||||
|
||||
| Method | Data Type | Recovery Behavior | Correction Behavior |
|
||||
|--------|-----------|-------------------|---------------------|
|
||||
| BC | Success only | ✗ | ✗ |
|
||||
| DAgger | Success + corrections | ✗ | ✓ |
|
||||
| HG-DAgger | Success + corrections | Sometimes | ✓ |
|
||||
| RaC | Success + recovery + correction | ✓ Explicit | ✓ |
|
||||
|
||||
---
|
||||
|
||||
## The RaC Pipeline
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ RaC Training Pipeline │
|
||||
├─────────────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 1. PRE-TRAINING (Standard BC) │
|
||||
│ └─> Train initial policy on clean demonstrations │
|
||||
│ │
|
||||
│ 2. RAC DATA COLLECTION (Human-in-the-loop) │
|
||||
│ ├─> Policy runs autonomously │
|
||||
│ ├─> Human monitors and intervenes when failure imminent │
|
||||
│ │ ├─> RECOVERY: Human teleoperates robot back to good state │
|
||||
│ │ └─> CORRECTION: Human completes the current subtask │
|
||||
│ └─> Episode terminates after correction (Rule 2) │
|
||||
│ │
|
||||
│ 3. REWARD LABELING (Optional: SARM) │
|
||||
│ └─> Compute progress rewards for advantage-weighted training │
|
||||
│ │
|
||||
│ 4. FINE-TUNING │
|
||||
│ └─> Train on combined demos + RaC data (optionally with RA-BC) │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step-by-Step Guide
|
||||
|
||||
### Step 1: Pre-train a Base Policy
|
||||
|
||||
First, train a policy on your demonstration dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/demo-dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=outputs/pretrain \
|
||||
--batch_size=32 \
|
||||
--steps=50000
|
||||
```
|
||||
|
||||
### Step 2: Collect RaC Data
|
||||
|
||||
Run the RaC data collection script with your pre-trained policy:
|
||||
|
||||
```bash
|
||||
python examples/rac/rac_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/rac-dataset \
|
||||
--dataset.single_task="Pick up the cube and place it in the bowl" \
|
||||
--dataset.num_episodes=50
|
||||
```
|
||||
|
||||
**Controls (Keyboard + Foot Pedal):**
|
||||
|
||||
| Key / Pedal | Action |
|
||||
|-------------|--------|
|
||||
| **SPACE** / Right pedal | Pause policy (teleop mirrors robot, no recording) |
|
||||
| **c** / Left pedal | Take control (start correction, recording resumes) |
|
||||
| **→** / Right pedal | End episode (save) - when in correction mode |
|
||||
| **←** | Re-record episode |
|
||||
| **ESC** | Stop session and push to hub |
|
||||
| Any key/pedal during reset | Start next episode |
|
||||
|
||||
**The RaC Protocol:**
|
||||
|
||||
1. Watch the policy run autonomously (teleop is idle/free)
|
||||
2. When you see imminent failure, press **SPACE** or **right pedal** to pause
|
||||
- Policy stops
|
||||
- Teleoperator moves to match robot position (torque enabled)
|
||||
- No frames recorded during pause
|
||||
3. Press **c** or **left pedal** to take control
|
||||
- Teleoperator torque disabled, free to move
|
||||
- **RECOVERY**: Teleoperate back to a good state
|
||||
- **CORRECTION**: Complete the subtask
|
||||
- All movements are recorded
|
||||
4. Press **→** or **right pedal** to save and end episode
|
||||
5. **RESET**: Teleop moves to robot position, you can move robot to starting position
|
||||
6. Press any key/pedal to start next episode
|
||||
|
||||
The recovery and correction segments teach the policy how to recover from errors.
|
||||
|
||||
**Foot Pedal Setup (Linux):**
|
||||
|
||||
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
|
||||
```bash
|
||||
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
|
||||
```
|
||||
|
||||
### Step 3: (Optional) Compute SARM Rewards
|
||||
|
||||
For advantage-weighted training (RA-BC / Pi0.6-style), compute SARM progress values:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/rac-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
### Step 4: Fine-tune Policy
|
||||
|
||||
Fine-tune on the RaC data:
|
||||
|
||||
```bash
|
||||
# Without RA-BC (standard fine-tuning)
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/rac-dataset \
|
||||
--policy.type=pi0 \
|
||||
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--output_dir=outputs/rac_finetune \
|
||||
--steps=20000
|
||||
|
||||
# With RA-BC (advantage-weighted, Pi0.6-style)
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/rac-dataset \
|
||||
--policy.type=pi0 \
|
||||
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--output_dir=outputs/rac_finetune_rabc \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Connection to Pi0.6 / RECAP
|
||||
|
||||
Pi0.6's RECAP method shares similar principles:
|
||||
- Collect autonomous rollouts + expert interventions
|
||||
- Use value function to compute **advantages**: A(s,a) = V(s') - V(s)
|
||||
- **Advantage conditioning**: Weight training based on expected improvement
|
||||
|
||||
In LeRobot, we can use **SARM** as the value function:
|
||||
- SARM progress φ(s) ∈ [0,1] measures task completion
|
||||
- Progress delta = φ(s') - φ(s) approximates advantage
|
||||
- RA-BC uses these to weight training samples (higher weight for good corrections)
|
||||
|
||||
---
|
||||
|
||||
## Tips for Effective RaC Collection
|
||||
|
||||
### When to Intervene
|
||||
|
||||
Intervene when you see:
|
||||
- Robot about to make an irreversible mistake
|
||||
- Robot hesitating or showing uncertain behavior
|
||||
- Robot deviating from expected trajectory
|
||||
|
||||
### Recovery: Teleoperating Back to Good State
|
||||
|
||||
During recovery, teleoperate the robot back to a state where:
|
||||
- The robot is in a familiar, in-distribution configuration
|
||||
- The current subtask can still be completed
|
||||
- The recovery trajectory itself is informative training data
|
||||
|
||||
### Quality of Corrections
|
||||
|
||||
During correction:
|
||||
- Provide **confident, clean** trajectories
|
||||
- Complete the current subtask fully
|
||||
- Don't overcorrect or add unnecessary movements
|
||||
|
||||
---
|
||||
|
||||
## Iterative Improvement
|
||||
|
||||
RaC can be applied iteratively:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Policy v0 (demos) │
|
||||
│ ↓ │
|
||||
│ RaC Collection (target current failure modes) → Policy v1 │
|
||||
│ ↓ │
|
||||
│ RaC Collection (target new failure modes) → Policy v2 │
|
||||
│ ↓ │
|
||||
│ ... (repeat until satisfactory performance) │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
Each iteration:
|
||||
1. Deploy current policy
|
||||
2. Collect RaC interventions on failure cases
|
||||
3. Fine-tune on accumulated data
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
```bibtex
|
||||
@article{hu2025rac,
|
||||
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
|
||||
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
|
||||
journal={arXiv preprint arXiv:2509.07953},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@article{kelly2019hgdagger,
|
||||
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
|
||||
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
|
||||
journal={arXiv preprint arXiv:1810.02890},
|
||||
year={2019}
|
||||
}
|
||||
|
||||
@article{pi2025recap,
|
||||
title={π∗0.6: a VLA That Learns From Experience},
|
||||
author={Physical Intelligence},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@article{chen2025sarm,
|
||||
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
|
||||
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
|
||||
journal={arXiv preprint arXiv:2509.25358},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
|
||||
Train with **no annotations** - uses linear progress from 0 to 1:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=single_stage \
|
||||
@@ -288,7 +288,7 @@ lerobot-train \
|
||||
Train with **dense annotations only** (sparse auto-generated):
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dense_only \
|
||||
@@ -307,7 +307,7 @@ lerobot-train \
|
||||
Train with **both sparse and dense annotations**:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dual \
|
||||
@@ -468,7 +468,7 @@ This script:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
# Training-Time RTC
|
||||
|
||||
Training-Time RTC teaches the model to handle inference delay during training.
|
||||
It feeds the **ground-truth action prefix** to the model and trains only on the remaining postfix actions.
|
||||
This keeps chunk transitions smooth without doing any inference-time inpainting.
|
||||
|
||||
Based on: [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964).
|
||||
|
||||
LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model parameters.
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
### At Training Time
|
||||
|
||||
- Sample a delay `d` per batch element.
|
||||
- Keep the first `d` action steps as **ground truth** (no noise).
|
||||
- Add noise only to the postfix actions.
|
||||
- Set the flow-matching timestep to **1.0** for prefix tokens and normal timesteps for postfix tokens.
|
||||
- Mask the loss to only train on the postfix.
|
||||
|
||||
### At Inference Time
|
||||
|
||||
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
|
||||
|
||||
- Replace prefix positions in `x_t` with previous chunk's leftover actions.
|
||||
- Set timestep to **1.0** for prefix positions.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start (CLI)
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=your/dataset \
|
||||
--policy.rtc_training_config.enabled=true \
|
||||
--policy.rtc_training_config.min_delay=0 \
|
||||
--policy.rtc_training_config.max_delay=6 \
|
||||
--policy.rtc_training_config.delay_distribution=UNIFORM
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Inference with Training-Time RTC
|
||||
|
||||
After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference:
|
||||
|
||||
```python
|
||||
policy = PI0Policy.from_pretrained("path/to/trained/model")
|
||||
# rtc_training_config is loaded from the saved config
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
batch,
|
||||
inference_delay=5, # estimated delay in timesteps
|
||||
prev_chunk_left_over=previous_actions, # from previous chunk
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Parameters
|
||||
|
||||
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
|
||||
|
||||
- **`enabled`**: Toggle training-time RTC (both training and inference).
|
||||
- **`min_delay` / `max_delay`**: Delay range (inclusive).
|
||||
- **`delay_distribution`**:
|
||||
- `UNIFORM`: uniform in `[min_delay, max_delay]`
|
||||
- `EXP`: exponentially decayed distribution over delays
|
||||
- **`exp_decay`**: Exponential decay factor for `EXP` sampling.
|
||||
|
||||
---
|
||||
|
||||
## Notes and Recommendations
|
||||
|
||||
- Start with `min_delay=0` and `max_delay` around your expected worst-case inference delay.
|
||||
- Use `EXP` if you want more supervision on smaller delays.
|
||||
|
||||
---
|
||||
|
||||
## Related Docs
|
||||
|
||||
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
|
||||
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla)
|
||||
@@ -216,7 +216,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
@@ -266,7 +266,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
|
||||
@@ -12,7 +12,6 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
@@ -157,30 +156,6 @@ lerobot-edit-dataset \
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Show the information of datasets
|
||||
|
||||
Show the information of datasets such as number of episode, number of frame, File size and so on.
|
||||
No change will be made to the dataset
|
||||
|
||||
```bash
|
||||
|
||||
# Show dataset information without feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
|
||||
# Show dataset information with feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `parameters`: The flag to control show or no show dataset information with feature details.(default=false)
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
@@ -45,7 +45,7 @@ policy.type=wall_x
|
||||
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=wall_x \
|
||||
--output_dir=./outputs/wallx_training \
|
||||
|
||||
@@ -154,7 +154,7 @@ lerobot-train \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
|
||||
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||
--output_dir=./outputs/xvla_bimanual \
|
||||
--job_name=xvla_so101_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.episode=2
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,416 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive debug script for OpenArms CAN FD communication.
|
||||
Tests all 4 CAN interfaces with CAN FD support.
|
||||
"""
|
||||
|
||||
import can
|
||||
import time
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
def check_can_interface(port):
|
||||
"""Check if CAN interface is UP and configured."""
|
||||
try:
|
||||
result = subprocess.run(['ip', 'link', 'show', port],
|
||||
capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
return False, "Interface not found", None
|
||||
|
||||
output = result.stdout
|
||||
if 'UP' not in output:
|
||||
return False, "Interface is DOWN", None
|
||||
|
||||
# Check if CAN FD is enabled
|
||||
is_fd = 'fd on' in output.lower() or 'canfd' in output.lower()
|
||||
|
||||
return True, "Interface is UP", is_fd
|
||||
except FileNotFoundError:
|
||||
return None, "Cannot check (ip command not found)", None
|
||||
|
||||
|
||||
def test_motor_on_interface(bus, motor_id, timeout=2.0, use_fd=False):
|
||||
"""
|
||||
Test a single motor and return all responses.
|
||||
|
||||
Returns:
|
||||
list of (arbitration_id, data) tuples for all responses received
|
||||
"""
|
||||
# Send enable command
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd
|
||||
)
|
||||
|
||||
try:
|
||||
bus.send(enable_msg)
|
||||
except Exception as e:
|
||||
return None, f"Send error: {e}"
|
||||
|
||||
# Listen for responses
|
||||
responses = []
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
msg = bus.recv(timeout=0.1)
|
||||
if msg:
|
||||
responses.append((msg.arbitration_id, msg.data, msg.is_fd if hasattr(msg, 'is_fd') else False))
|
||||
|
||||
# Send disable command
|
||||
disable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd
|
||||
)
|
||||
try:
|
||||
bus.send(disable_msg)
|
||||
except:
|
||||
pass
|
||||
|
||||
return responses, None
|
||||
|
||||
|
||||
def test_interface(port, interface_type="socketcan", use_can_fd=True):
|
||||
"""Test all 8 motors on a single CAN interface."""
|
||||
|
||||
results = {
|
||||
'interface': port,
|
||||
'status': None,
|
||||
'is_fd': use_can_fd,
|
||||
'motors': {}
|
||||
}
|
||||
|
||||
# Check interface status
|
||||
status_ok, status_msg, interface_has_fd = check_can_interface(port)
|
||||
|
||||
if interface_has_fd is not None:
|
||||
results['interface_fd_enabled'] = interface_has_fd
|
||||
if use_can_fd and not interface_has_fd:
|
||||
status_msg += " (CAN FD NOT enabled on interface!)"
|
||||
elif interface_has_fd:
|
||||
status_msg += " (CAN FD enabled)"
|
||||
|
||||
results['status'] = status_msg
|
||||
|
||||
if status_ok is False:
|
||||
return results
|
||||
|
||||
# Try to connect
|
||||
try:
|
||||
if use_can_fd:
|
||||
print(f" Connecting to {port} with CAN FD (1 Mbps / 5 Mbps)...")
|
||||
bus = can.interface.Bus(
|
||||
channel=port,
|
||||
interface=interface_type,
|
||||
bitrate=1000000,
|
||||
data_bitrate=5000000,
|
||||
fd=True
|
||||
)
|
||||
else:
|
||||
print(f" Connecting to {port} with CAN 2.0 (1 Mbps)...")
|
||||
bus = can.interface.Bus(
|
||||
channel=port,
|
||||
interface=interface_type,
|
||||
bitrate=1000000
|
||||
)
|
||||
except Exception as e:
|
||||
results['status'] = f"Connection failed: {e}"
|
||||
return results
|
||||
|
||||
try:
|
||||
# Clear any pending messages
|
||||
while bus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
# Test each motor (0x01 to 0x08)
|
||||
for motor_id in range(0x01, 0x09):
|
||||
responses, error = test_motor_on_interface(bus, motor_id, timeout=1.0, use_fd=use_can_fd)
|
||||
|
||||
if error:
|
||||
results['motors'][motor_id] = {'error': error}
|
||||
elif responses:
|
||||
results['motors'][motor_id] = {
|
||||
'found': True,
|
||||
'responses': responses
|
||||
}
|
||||
else:
|
||||
results['motors'][motor_id] = {
|
||||
'found': False,
|
||||
'responses': []
|
||||
}
|
||||
|
||||
time.sleep(0.05) # Small delay between motors
|
||||
|
||||
finally:
|
||||
bus.shutdown()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def print_results(all_results):
|
||||
"""Print formatted results for all interfaces."""
|
||||
|
||||
print("SUMMARY - Motors Found on Each Interface")
|
||||
|
||||
motor_names = {
|
||||
0x01: "joint_1 (Shoulder pan)",
|
||||
0x02: "joint_2 (Shoulder lift)",
|
||||
0x03: "joint_3 (Shoulder rotation)",
|
||||
0x04: "joint_4 (Elbow flex)",
|
||||
0x05: "joint_5 (Wrist roll)",
|
||||
0x06: "joint_6 (Wrist pitch)",
|
||||
0x07: "joint_7 (Wrist rotation)",
|
||||
0x08: "gripper",
|
||||
}
|
||||
|
||||
total_found = 0
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
status = result['status']
|
||||
|
||||
print(f"{interface}: {status}")
|
||||
if result.get('is_fd'):
|
||||
print(f" Mode: CAN FD")
|
||||
else:
|
||||
print(f" Mode: CAN 2.0")
|
||||
|
||||
if 'Connection failed' in status or 'DOWN' in status:
|
||||
print(f" ⚠ Cannot test {interface}")
|
||||
continue
|
||||
|
||||
motors_found = 0
|
||||
|
||||
for motor_id in range(0x01, 0x09):
|
||||
motor_data = result['motors'].get(motor_id, {})
|
||||
motor_name = motor_names.get(motor_id, "Unknown")
|
||||
|
||||
if motor_data.get('error'):
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {motor_data['error']}")
|
||||
elif motor_data.get('found'):
|
||||
motors_found += 1
|
||||
total_found += 1
|
||||
responses = motor_data['responses']
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
|
||||
|
||||
for resp_id, data, is_fd in responses:
|
||||
data_hex = data.hex()
|
||||
fd_flag = " [FD]" if is_fd else " [2.0]"
|
||||
print(f" → Response from 0x{resp_id:02X}{fd_flag}: {data_hex}")
|
||||
else:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
|
||||
|
||||
print(f"\n Summary: {motors_found}/8 motors found on {interface}")
|
||||
|
||||
# Overall summary
|
||||
print("OVERALL SUMMARY")
|
||||
print(f"Total motors found across all interfaces: {total_found}")
|
||||
|
||||
# Analyze configuration
|
||||
print("DIAGNOSIS")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
|
||||
if motors_found == 0:
|
||||
print(f"\n⚠ {interface}: NO MOTORS FOUND")
|
||||
print(" Possible issues:")
|
||||
print(" 1. CAN FD mode mismatch (interface vs motor configuration)")
|
||||
print(" 2. Missing 120Ω termination resistors at BOTH cable ends")
|
||||
print(" 3. Motor timeout parameter set incorrectly (should NOT be 0)")
|
||||
print(" 4. CANH/CANL wiring issue")
|
||||
print(" 5. Cable too long (>40m for CAN FD at 5Mbps)")
|
||||
|
||||
# Check FD mismatch
|
||||
if result.get('is_fd') and not result.get('interface_fd_enabled'):
|
||||
print(" ⚠️ CRITICAL: Trying CAN FD but interface NOT configured for FD!")
|
||||
print(f" Fix: sudo ip link set {interface} type can bitrate 1000000 dbitrate 5000000 fd on")
|
||||
|
||||
elif motors_found < 8:
|
||||
print(f"\n⚠ {interface}: Only {motors_found}/8 motors responding")
|
||||
print(" Check power and connections for missing motors")
|
||||
else:
|
||||
print(f"\n✓ {interface}: All 8 motors responding correctly!")
|
||||
|
||||
# Check for unexpected response IDs
|
||||
print("RESPONSE ID ANALYSIS")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
unexpected = []
|
||||
|
||||
for motor_id, motor_data in result['motors'].items():
|
||||
if motor_data.get('found'):
|
||||
expected_id = motor_id + 0x10
|
||||
actual_ids = [resp[0] for resp in motor_data['responses']]
|
||||
|
||||
if expected_id not in actual_ids:
|
||||
unexpected.append((motor_id, actual_ids))
|
||||
|
||||
if unexpected:
|
||||
print(f"\n⚠ {interface}: Unexpected response IDs detected")
|
||||
for motor_id, actual_ids in unexpected:
|
||||
expected_id = motor_id + 0x10
|
||||
print(f" Motor 0x{motor_id:02X}: Expected 0x{expected_id:02X}, "
|
||||
f"got {[f'0x{id:02X}' for id in actual_ids]}")
|
||||
print(" → Motor Master IDs need reconfiguration")
|
||||
else:
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
if motors_found > 0:
|
||||
print(f"\n✓ {interface}: All responding motors use correct IDs")
|
||||
|
||||
|
||||
def test_communication_speed(interface, motor_id, num_iterations=100):
|
||||
"""
|
||||
Test communication speed with a motor.
|
||||
|
||||
Returns:
|
||||
tuple: (hz, avg_latency_ms) or (None, None) if test failed
|
||||
"""
|
||||
try:
|
||||
# Connect to interface
|
||||
bus = can.interface.Bus(
|
||||
channel=interface,
|
||||
interface="socketcan",
|
||||
bitrate=1000000,
|
||||
data_bitrate=5000000,
|
||||
fd=True
|
||||
)
|
||||
|
||||
# Send refresh commands and measure round-trip time
|
||||
latencies = []
|
||||
successful = 0
|
||||
|
||||
for _ in range(num_iterations):
|
||||
start = time.perf_counter()
|
||||
|
||||
# Send enable command (lightweight operation)
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=True
|
||||
)
|
||||
bus.send(enable_msg)
|
||||
|
||||
# Wait for response
|
||||
msg = bus.recv(timeout=0.1)
|
||||
|
||||
if msg:
|
||||
latency = (time.perf_counter() - start) * 1000 # Convert to ms
|
||||
latencies.append(latency)
|
||||
successful += 1
|
||||
|
||||
bus.shutdown()
|
||||
|
||||
if successful > 0:
|
||||
avg_latency = sum(latencies) / len(latencies)
|
||||
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
|
||||
return hz, avg_latency
|
||||
|
||||
return None, None
|
||||
|
||||
except Exception as e:
|
||||
print(f" Speed test error: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to test all CAN interfaces with CAN FD."""
|
||||
|
||||
print("\nThis will test all 4 CAN interfaces (can0-can3) with CAN FD")
|
||||
print("Testing motors 0x01-0x08 on each interface")
|
||||
print()
|
||||
print("Make sure:")
|
||||
print(" ✓ Motors are powered (24V)")
|
||||
print(" ✓ CAN interfaces configured with FD mode:")
|
||||
print(" ./examples/openarms/setup_can.sh")
|
||||
print(" ✓ Motor 'timeout' parameter NOT set to 0 (use Damiao tools)")
|
||||
print(" ✓ CAN wiring includes 120Ω termination at BOTH ends")
|
||||
print()
|
||||
|
||||
input("Press ENTER to start testing...")
|
||||
|
||||
# Test all 4 interfaces with CAN FD
|
||||
all_results = []
|
||||
|
||||
for i in range(4):
|
||||
interface = f"can{i}"
|
||||
print(f"Testing {interface}...")
|
||||
|
||||
result = test_interface(interface, use_can_fd=True)
|
||||
all_results.append(result)
|
||||
|
||||
# Quick status
|
||||
if 'Connection failed' in result['status'] or 'DOWN' in result['status']:
|
||||
print(f" ⚠ {interface}: {result['status']}")
|
||||
else:
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
print(f" {interface}: {motors_found}/8 motors found")
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
# Print detailed results
|
||||
print_results(all_results)
|
||||
|
||||
print("Testing Complete!")
|
||||
|
||||
all_found = sum(sum(1 for m in r['motors'].values() if m.get('found')) for r in all_results)
|
||||
|
||||
if all_found == 0:
|
||||
print("\n⚠️ CRITICAL: No motors found on any interface!")
|
||||
print("\nTop issues to check:")
|
||||
print(" 1. Motor 'timeout' parameter (use Damiao tools to set > 0)")
|
||||
print(" 2. CAN FD not enabled (run ./examples/openarms/setup_can.sh)")
|
||||
print(" 3. Missing termination resistors")
|
||||
print("\nTry:")
|
||||
print(" a) Check motor parameters with Damiao Debugging Tools")
|
||||
print(" b) Verify CAN FD is enabled: ip -d link show can0 | grep fd")
|
||||
print(" c) Run setup script: ./examples/openarms/setup_can.sh")
|
||||
else:
|
||||
# Run speed test on interfaces with motors
|
||||
print("COMMUNICATION SPEED TEST")
|
||||
print("\nTesting maximum communication frequency...")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
|
||||
# Find first responding motor
|
||||
responding_motor = None
|
||||
for motor_id, motor_data in result['motors'].items():
|
||||
if motor_data.get('found'):
|
||||
responding_motor = motor_id
|
||||
break
|
||||
|
||||
if responding_motor:
|
||||
print(f"\n{interface}: Testing with motor 0x{responding_motor:02X}...")
|
||||
hz, latency = test_communication_speed(interface, responding_motor, num_iterations=100)
|
||||
|
||||
if hz:
|
||||
print(f" ✓ Max frequency: {hz:.1f} Hz")
|
||||
print(f" ✓ Avg latency: {latency:.2f} ms")
|
||||
print(f" ✓ Commands per second: ~{int(hz)}")
|
||||
else:
|
||||
print(f" ✗ Speed test failed")
|
||||
else:
|
||||
print(f"\n{interface}: No motors found, skipping speed test")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nTesting interrupted by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\nUnexpected error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
@@ -0,0 +1,360 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Policy Evaluation
|
||||
|
||||
Evaluates a trained policy on the OpenArms robot by running inference and recording
|
||||
the evaluation episodes to a dataset. Supports optional leader arm for manual resets.
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/evaluate.py
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
HF_MODEL_ID = "lerobot-data-collection/level1_rac2_100k" # TODO: Replace with your trained model
|
||||
HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval_raccc3" # TODO: Replace with your eval dataset name
|
||||
TASK_DESCRIPTION = "Fold the T-shirt properly" # TODO: Replace with your task, this should match!!
|
||||
|
||||
NUM_EPISODES = 1
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 1000
|
||||
RESET_TIME_SEC = 60
|
||||
|
||||
# Robot CAN interfaces
|
||||
FOLLOWER_LEFT_PORT = "can0"
|
||||
FOLLOWER_RIGHT_PORT = "can1"
|
||||
|
||||
# If enabled, you can manually reset the environment between evaluation episodes
|
||||
USE_LEADER_FOR_RESETS = False # Set to False if you don't want to use leader
|
||||
LEADER_LEFT_PORT = "can2"
|
||||
LEADER_RIGHT_PORT = "can3"
|
||||
|
||||
# Camera configuration
|
||||
CAMERA_CONFIG = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=1280, height=720, fps=FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=1280, height=720, fps=FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video2", width=640, height=480, fps=FPS),
|
||||
}
|
||||
|
||||
def main():
|
||||
"""Main evaluation function."""
|
||||
print("OpenArms Policy Evaluation")
|
||||
print(f"\nModel: {HF_MODEL_ID}")
|
||||
print(f"Evaluation Dataset: {HF_EVAL_DATASET_ID}")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print(f"Episodes: {NUM_EPISODES}")
|
||||
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
||||
print(f"Reset Duration: {RESET_TIME_SEC}s")
|
||||
print(f"Use Leader for Resets: {USE_LEADER_FOR_RESETS}")
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left=FOLLOWER_LEFT_PORT,
|
||||
port_right=FOLLOWER_RIGHT_PORT,
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=CAMERA_CONFIG,
|
||||
)
|
||||
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
follower.connect(calibrate=False)
|
||||
|
||||
if not follower.is_connected:
|
||||
raise RuntimeError("Follower robot failed to connect!")
|
||||
|
||||
|
||||
leader = None
|
||||
if USE_LEADER_FOR_RESETS:
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left=LEADER_LEFT_PORT,
|
||||
port_right=LEADER_RIGHT_PORT,
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
leader.connect(calibrate=False)
|
||||
|
||||
if not leader.is_connected:
|
||||
raise RuntimeError("Leader robot failed to connect!")
|
||||
|
||||
# Enable gravity compensation
|
||||
if leader.pin_robot is not None:
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
print(f"Leader connected with gravity compensation ({LEADER_LEFT_PORT}, {LEADER_RIGHT_PORT})")
|
||||
else:
|
||||
print(f"Leader connected but gravity compensation unavailable (no URDF)")
|
||||
|
||||
# Build default processors for action and observation
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Build dataset features from robot features and processors
|
||||
# For actions, only include positions (no velocity or torque)
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if dataset already exists
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
|
||||
if dataset_path.exists():
|
||||
print(f"Evaluation dataset already exists at: {dataset_path}")
|
||||
print("This will append new episodes to the existing dataset.")
|
||||
choice = input(" Continue? (y/n): ").strip().lower()
|
||||
if choice != 'y':
|
||||
print(" Aborting evaluation.")
|
||||
follower.disconnect()
|
||||
if leader:
|
||||
leader.disconnect()
|
||||
return
|
||||
|
||||
# Create dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_EVAL_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_processes=0,
|
||||
image_writer_threads=12,
|
||||
)
|
||||
|
||||
# Load policy config from pretrained model and create policy using factory
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy.config,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(policy.config.device)}
|
||||
},
|
||||
)
|
||||
|
||||
print(f"\nRunning evaluation...")
|
||||
# Initialize keyboard listener and visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_evaluation")
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print(f"\nRunning inference for episode {episode_idx + 1}...")
|
||||
|
||||
# Run inference with policy
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Handle re-recording
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Save episode
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||
print(f"Saving episode {episode_idx + 1} ({dataset.episode_buffer['size']} frames)...")
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Reset environment between episodes (if not last episode)
|
||||
if not events["stop_recording"] and episode_idx < NUM_EPISODES:
|
||||
if USE_LEADER_FOR_RESETS and leader:
|
||||
log_say("Reset the environment using leader arms")
|
||||
print(f"\nManual reset period ({RESET_TIME_SEC}s)...")
|
||||
|
||||
# Use leader for manual reset with gravity compensation
|
||||
import numpy as np
|
||||
|
||||
dt = 1 / FPS
|
||||
reset_start_time = time.perf_counter()
|
||||
|
||||
while time.perf_counter() - reset_start_time < RESET_TIME_SEC:
|
||||
if events["exit_early"] or events["stop_recording"]:
|
||||
break
|
||||
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity and friction torques
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=1.0
|
||||
)
|
||||
|
||||
# Combine torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply compensation
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor, kp=0.0, kd=kd,
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor, kp=0.0, kd=kd,
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower
|
||||
follower_action = {}
|
||||
for joint in leader_positions_deg.keys():
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
if follower_action:
|
||||
follower.send_action(follower_action)
|
||||
|
||||
# Maintain loop rate
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
sleep_time = dt - loop_duration
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
print("Reset complete")
|
||||
else:
|
||||
log_say("Waiting for manual reset")
|
||||
print(f"Manually reset the environment and press ENTER to continue")
|
||||
input("Press ENTER when ready...")
|
||||
|
||||
print(f"Evaluation complete! {episode_idx} episodes recorded")
|
||||
log_say("Evaluation complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nEvaluation interrupted by user")
|
||||
|
||||
finally:
|
||||
if leader:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
|
||||
follower.disconnect()
|
||||
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
print("\nUploading to Hugging Face Hub...")
|
||||
dataset.push_to_hub(private=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,703 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Policy Evaluation with Real-Time Chunking (RTC)
|
||||
|
||||
Evaluates a trained policy on the OpenArms robot using RTC for smooth, continuous motion.
|
||||
RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion
|
||||
despite high inference latency by asynchronously generating action chunks.
|
||||
|
||||
Features:
|
||||
- Thread-based asynchronous action generation and execution
|
||||
- RTC for smooth transitions between action chunks
|
||||
- Dataset recording for evaluation episodes
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/evaluate_with_rtc.py
|
||||
|
||||
# With custom RTC parameters
|
||||
python examples/openarms/evaluate_with_rtc.py \
|
||||
--rtc.execution_horizon=12 \
|
||||
--rtc.max_guidance_weight=10.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging, log_say
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Default Configuration Constants
|
||||
# ============================================================================
|
||||
|
||||
DEFAULT_HF_MODEL_ID = "lerobot-data-collection/level1_rac3_100k"
|
||||
DEFAULT_HF_EVAL_DATASET_ID = "lerobot-data-collection/test"
|
||||
DEFAULT_TASK_DESCRIPTION = "Fold the T-shirt properly"
|
||||
|
||||
DEFAULT_NUM_EPISODES = 1
|
||||
DEFAULT_FPS = 30
|
||||
DEFAULT_EPISODE_TIME_SEC = 1000
|
||||
DEFAULT_RESET_TIME_SEC = 60
|
||||
|
||||
DEFAULT_FOLLOWER_LEFT_PORT = "can0"
|
||||
DEFAULT_FOLLOWER_RIGHT_PORT = "can1"
|
||||
|
||||
DEFAULT_CAMERA_CONFIG = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=1280, height=720, fps=DEFAULT_FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video4", width=1280, height=720, fps=DEFAULT_FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video2", width=640, height=480, fps=DEFAULT_FPS),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Thread-Safe Robot Wrapper
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
"""Thread-safe wrapper for robot operations."""
|
||||
|
||||
def __init__(self, robot: OpenArmsFollower):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict) -> None:
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.robot.name
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenArmsRTCEvalConfig(HubMixin):
|
||||
"""Configuration for OpenArms evaluation with RTC."""
|
||||
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=20,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
)
|
||||
)
|
||||
|
||||
model_id: str = DEFAULT_HF_MODEL_ID
|
||||
eval_dataset_id: str = DEFAULT_HF_EVAL_DATASET_ID
|
||||
task: str = DEFAULT_TASK_DESCRIPTION
|
||||
|
||||
num_episodes: int = DEFAULT_NUM_EPISODES
|
||||
fps: float = DEFAULT_FPS
|
||||
episode_time_sec: float = DEFAULT_EPISODE_TIME_SEC
|
||||
reset_time_sec: float = DEFAULT_RESET_TIME_SEC
|
||||
|
||||
follower_left_port: str = DEFAULT_FOLLOWER_LEFT_PORT
|
||||
follower_right_port: str = DEFAULT_FOLLOWER_RIGHT_PORT
|
||||
|
||||
device: str = "cuda"
|
||||
|
||||
# Should be higher than inference_delay + execution_horizon
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
record_dataset: bool = True
|
||||
push_to_hub: bool = True
|
||||
|
||||
interpolation: bool = True
|
||||
|
||||
use_torch_compile: bool = False
|
||||
torch_compile_backend: str = "inductor"
|
||||
torch_compile_mode: str = "default"
|
||||
torch_compile_disable_cudagraphs: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.model_id = policy_path
|
||||
elif self.model_id:
|
||||
self.policy = PreTrainedConfig.from_pretrained(self.model_id)
|
||||
self.policy.pretrained_path = self.model_id
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Action Generation Thread
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_actions_thread(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: OpenArmsRTCEvalConfig,
|
||||
episode_active: Event,
|
||||
):
|
||||
"""Thread function to asynchronously generate action chunks from the policy."""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting action generation thread")
|
||||
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.fps
|
||||
|
||||
hw_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not episode_active.is_set():
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
hw_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task]
|
||||
obs_with_policy_features["robot_type"] = robot.name
|
||||
|
||||
preprocessed_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
postprocessed_actions = postprocessor(actions).squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] action_queue_size_to_get_new_actions too small. "
|
||||
"Should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[GET_ACTIONS] Generated chunk, latency={new_latency:.3f}s, "
|
||||
f"delay={new_delay}, queue_size={action_queue.qsize()}"
|
||||
)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
logger.info("[GET_ACTIONS] Action generation thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
shutdown_event.set()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Action Execution Thread
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def actor_thread(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: OpenArmsRTCEvalConfig,
|
||||
episode_active: Event,
|
||||
dataset: LeRobotDataset | None,
|
||||
dataset_lock: Lock,
|
||||
teleop_action_processor,
|
||||
robot_observation_processor,
|
||||
):
|
||||
"""Thread function to execute actions on the robot."""
|
||||
try:
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
|
||||
if cfg.interpolation:
|
||||
interp_factor = 2
|
||||
robot_interval = 1.0 / (cfg.fps * interp_factor)
|
||||
logger.info(f"[ACTOR] Interpolation ON: policy={cfg.fps}Hz -> robot={cfg.fps * interp_factor}Hz (2x)")
|
||||
else:
|
||||
interp_factor = 1
|
||||
robot_interval = 1.0 / cfg.fps
|
||||
logger.info(f"[ACTOR] Interpolation OFF: policy={cfg.fps}Hz, robot={cfg.fps}Hz")
|
||||
|
||||
prev_action: Tensor | None = None
|
||||
interpolated_actions: list[Tensor] = []
|
||||
interp_idx = 0
|
||||
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_print = time.perf_counter()
|
||||
last_dataset_time = 0.0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not episode_active.is_set():
|
||||
prev_action = None
|
||||
interpolated_actions = []
|
||||
interp_idx = 0
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_print = time.perf_counter()
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if interp_idx >= len(interpolated_actions):
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
current_action = new_action.cpu()
|
||||
policy_consume_count += 1
|
||||
|
||||
if cfg.interpolation and prev_action is not None:
|
||||
mid = prev_action + 0.5 * (current_action - prev_action)
|
||||
interpolated_actions = [mid, current_action]
|
||||
else:
|
||||
interpolated_actions = [current_action]
|
||||
|
||||
prev_action = current_action
|
||||
interp_idx = 0
|
||||
|
||||
if interp_idx < len(interpolated_actions):
|
||||
action_to_send = interpolated_actions[interp_idx]
|
||||
interp_idx += 1
|
||||
|
||||
action_dict = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action_to_send):
|
||||
action_dict[key] = action_to_send[i].item()
|
||||
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
robot_send_count += 1
|
||||
|
||||
if cfg.record_dataset and dataset is not None:
|
||||
now = time.perf_counter()
|
||||
if now - last_dataset_time >= (1.0 / cfg.fps):
|
||||
last_dataset_time = now
|
||||
with dataset_lock:
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
action_for_dataset = teleop_action_processor((action_dict, None))
|
||||
frame = {}
|
||||
for key, value in obs_processed.items():
|
||||
frame[f"observation.{key}"] = value
|
||||
for key, value in action_for_dataset.items():
|
||||
frame[f"action.{key}"] = value
|
||||
frame["task"] = cfg.task
|
||||
dataset.add_frame(frame)
|
||||
|
||||
now = time.perf_counter()
|
||||
if now - last_hz_print >= 5.0:
|
||||
elapsed = now - last_hz_print
|
||||
actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0
|
||||
actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0
|
||||
logger.info(f"[ACTOR] Actual Hz - Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}")
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_print = now
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
sleep_time = max(0, robot_interval - dt_s - 0.001)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
logger.info("[ACTOR] Shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
shutdown_event.set()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Evaluation Function
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: OpenArmsRTCEvalConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method."""
|
||||
if policy.name in ["pi05", "pi0"]:
|
||||
return policy
|
||||
|
||||
try:
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: OpenArmsRTCEvalConfig):
|
||||
"""Main evaluation function with RTC."""
|
||||
init_logging()
|
||||
|
||||
print("=" * 60)
|
||||
print("OpenArms Policy Evaluation with RTC")
|
||||
print("=" * 60)
|
||||
print(f"\nModel: {cfg.model_id}")
|
||||
print(f"Evaluation Dataset: {cfg.eval_dataset_id}")
|
||||
print(f"Task: {cfg.task}")
|
||||
print(f"Episodes: {cfg.num_episodes}")
|
||||
print(f"Episode Duration: {cfg.episode_time_sec}s")
|
||||
print(f"RTC Enabled: {cfg.rtc.enabled}")
|
||||
print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}")
|
||||
print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
|
||||
print(f"Policy Hz: {cfg.fps}")
|
||||
print(f"Robot Hz: {cfg.fps * 2 if cfg.interpolation else cfg.fps}")
|
||||
print(f"Interpolation: {cfg.interpolation}")
|
||||
print(f"Device: {cfg.device}")
|
||||
print("=" * 60)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
episode_active = Event()
|
||||
|
||||
# Initialize Robot
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left=cfg.follower_left_port,
|
||||
port_right=cfg.follower_right_port,
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=DEFAULT_CAMERA_CONFIG,
|
||||
)
|
||||
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
follower.connect(calibrate=False)
|
||||
|
||||
if not follower.is_connected:
|
||||
raise RuntimeError("Follower robot failed to connect!")
|
||||
|
||||
robot = RobotWrapper(follower)
|
||||
logger.info("Follower robot connected")
|
||||
|
||||
# Build Processors and Dataset Features
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create or Load Dataset
|
||||
dataset = None
|
||||
dataset_lock = Lock()
|
||||
|
||||
if cfg.record_dataset:
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / cfg.eval_dataset_id
|
||||
if dataset_path.exists():
|
||||
logger.info(f"Evaluation dataset exists at: {dataset_path}")
|
||||
logger.info("New episodes will be appended.")
|
||||
choice = input("Continue? (y/n): ").strip().lower()
|
||||
if choice != "y":
|
||||
logger.info("Aborting evaluation.")
|
||||
follower.disconnect()
|
||||
return
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=cfg.eval_dataset_id,
|
||||
fps=int(cfg.fps),
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_processes=0,
|
||||
image_writer_threads=12,
|
||||
)
|
||||
logger.info(f"Dataset created: {cfg.eval_dataset_id}")
|
||||
|
||||
# Load Policy
|
||||
logger.info(f"Loading policy from: {cfg.model_id}")
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type in ["pi05", "pi0"]:
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
logger.info(f"Policy loaded: {policy.name}")
|
||||
|
||||
# Create Action Queue and Start Threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
get_actions_t = Thread(
|
||||
target=get_actions_thread,
|
||||
args=(
|
||||
policy,
|
||||
robot,
|
||||
robot_observation_processor,
|
||||
action_queue,
|
||||
shutdown_event,
|
||||
cfg,
|
||||
episode_active,
|
||||
),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_t.start()
|
||||
logger.info("Started action generation thread")
|
||||
|
||||
actor_t = Thread(
|
||||
target=actor_thread,
|
||||
args=(
|
||||
robot,
|
||||
robot_action_processor,
|
||||
action_queue,
|
||||
shutdown_event,
|
||||
cfg,
|
||||
episode_active,
|
||||
dataset,
|
||||
dataset_lock,
|
||||
teleop_action_processor,
|
||||
robot_observation_processor,
|
||||
),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_t.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
# Run Evaluation Episodes
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < cfg.num_episodes and not shutdown_event.is_set():
|
||||
log_say(f"Evaluating episode {episode_idx + 1} of {cfg.num_episodes}")
|
||||
logger.info(f"\n{'='*40}")
|
||||
logger.info(f"Episode {episode_idx + 1} / {cfg.num_episodes}")
|
||||
logger.info(f"{'='*40}")
|
||||
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
episode_active.set()
|
||||
episode_start_time = time.time()
|
||||
|
||||
while (time.time() - episode_start_time) < cfg.episode_time_sec:
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
|
||||
elapsed = time.time() - episode_start_time
|
||||
if int(elapsed) % 10 == 0 and int(elapsed) > 0:
|
||||
logger.info(
|
||||
f"[MAIN] Episode progress: {elapsed:.0f}/{cfg.episode_time_sec}s, "
|
||||
f"queue_size={action_queue.qsize()}"
|
||||
)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
episode_active.clear()
|
||||
logger.info(f"Episode {episode_idx + 1} completed")
|
||||
|
||||
if cfg.record_dataset and dataset is not None:
|
||||
with dataset_lock:
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||
logger.info(
|
||||
f"Saving episode {episode_idx + 1} "
|
||||
f"({dataset.episode_buffer['size']} frames)"
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
episode_idx += 1
|
||||
|
||||
# Manual reset between episodes
|
||||
if not shutdown_event.is_set() and episode_idx < cfg.num_episodes:
|
||||
log_say("Waiting for manual reset")
|
||||
logger.info("Manually reset the environment and press ENTER to continue")
|
||||
input("Press ENTER when ready...")
|
||||
|
||||
logger.info(f"Evaluation complete! {episode_idx} episodes recorded")
|
||||
log_say("Evaluation complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n\nEvaluation interrupted by user")
|
||||
|
||||
finally:
|
||||
shutdown_event.set()
|
||||
episode_active.clear()
|
||||
|
||||
if get_actions_t.is_alive():
|
||||
logger.info("Waiting for action generation thread to finish...")
|
||||
get_actions_t.join(timeout=5.0)
|
||||
|
||||
if actor_t.is_alive():
|
||||
logger.info("Waiting for actor thread to finish...")
|
||||
actor_t.join(timeout=5.0)
|
||||
|
||||
follower.disconnect()
|
||||
logger.info("Follower disconnected")
|
||||
|
||||
if cfg.record_dataset and dataset is not None:
|
||||
dataset.finalize()
|
||||
if cfg.push_to_hub:
|
||||
logger.info("Uploading to Hugging Face Hub...")
|
||||
dataset.push_to_hub(private=True)
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,216 @@
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
|
||||
# Friction model parameters from OpenArms config/follower.yaml
|
||||
# τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
|
||||
# For 8 motors: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
FRICTION_PARAMS = {
|
||||
"Fc": [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512], # Coulomb friction [Nm]
|
||||
"k": [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000], # tanh steepness
|
||||
"Fv": [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084], # Viscous friction [Nm·s/rad]
|
||||
"Fo": [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050], # Offset torque [Nm]
|
||||
}
|
||||
|
||||
# Constants from OpenArms C++ implementation
|
||||
AMP_TMP = 1.0
|
||||
COEF_TMP = 0.1
|
||||
|
||||
FRICTION_SCALE = 1.0 # OpenArms C++ uses 0.3 factor in unilateral mode
|
||||
DAMPING_KD = [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1] # Damping gains for stability
|
||||
|
||||
def compute_friction_torque(velocity_rad_per_sec: float, motor_index: int) -> float:
|
||||
"""
|
||||
Compute friction torque for a single motor using the tanh friction model.
|
||||
|
||||
Args:
|
||||
velocity_rad_per_sec: Angular velocity in rad/s
|
||||
motor_index: Index of the motor (0-7)
|
||||
|
||||
Returns:
|
||||
Friction torque in N·m (scaled for stability)
|
||||
"""
|
||||
|
||||
Fc = FRICTION_PARAMS["Fc"][motor_index]
|
||||
k = FRICTION_PARAMS["k"][motor_index]
|
||||
Fv = FRICTION_PARAMS["Fv"][motor_index]
|
||||
Fo = FRICTION_PARAMS["Fo"][motor_index]
|
||||
|
||||
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
|
||||
friction_torque = (
|
||||
AMP_TMP * Fc * np.tanh(COEF_TMP * k * velocity_rad_per_sec) +
|
||||
Fv * velocity_rad_per_sec +
|
||||
Fo
|
||||
)
|
||||
|
||||
# Scale down friction compensation for stability at lower control rates
|
||||
# (OpenArms C++ uses 0.3 factor in unilateral mode)!!
|
||||
friction_torque *= FRICTION_SCALE
|
||||
|
||||
return friction_torque
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0,
|
||||
)
|
||||
|
||||
print("Initializing robot...")
|
||||
follower = OpenArmsFollower(config)
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
print(f"Applying friction compensation")
|
||||
print(" 1. Support the arm before starting")
|
||||
print(" 2. The arm will be held in place by friction compensation")
|
||||
print(" 3. You should be able to move it with gentle force")
|
||||
print("\nPress ENTER when ready to start...")
|
||||
input()
|
||||
|
||||
print(f"✓ Motors enabled")
|
||||
print("\nStarting friction compensation loop...")
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
# Motor name to index mapping
|
||||
motor_name_to_index = {
|
||||
"joint_1": 0,
|
||||
"joint_2": 1,
|
||||
"joint_3": 2,
|
||||
"joint_4": 3,
|
||||
"joint_5": 4,
|
||||
"joint_6": 5,
|
||||
"joint_7": 6,
|
||||
"gripper": 7,
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get current joint positions and velocities from robot
|
||||
obs = follower.get_observation()
|
||||
|
||||
# Extract velocities in degrees per second
|
||||
velocities_deg_per_sec = {}
|
||||
positions_deg = {}
|
||||
|
||||
for motor in follower.bus_right.motors:
|
||||
vel_key = f"right_{motor}.vel"
|
||||
pos_key = f"right_{motor}.pos"
|
||||
if vel_key in obs:
|
||||
velocities_deg_per_sec[f"right_{motor}"] = obs[vel_key]
|
||||
if pos_key in obs:
|
||||
positions_deg[f"right_{motor}"] = obs[pos_key]
|
||||
|
||||
for motor in follower.bus_left.motors:
|
||||
vel_key = f"left_{motor}.vel"
|
||||
pos_key = f"left_{motor}.pos"
|
||||
if vel_key in obs:
|
||||
velocities_deg_per_sec[f"left_{motor}"] = obs[vel_key]
|
||||
if pos_key in obs:
|
||||
positions_deg[f"left_{motor}"] = obs[pos_key]
|
||||
|
||||
# Convert velocities to rad/s and compute friction torques
|
||||
friction_torques_nm = {}
|
||||
for motor_full_name, velocity_deg_per_sec in velocities_deg_per_sec.items():
|
||||
# Extract motor name without arm prefix
|
||||
if motor_full_name.startswith("right_"):
|
||||
motor_name = motor_full_name.removeprefix("right_")
|
||||
elif motor_full_name.startswith("left_"):
|
||||
motor_name = motor_full_name.removeprefix("left_")
|
||||
else:
|
||||
continue
|
||||
|
||||
# Get motor index for friction parameters
|
||||
motor_index = motor_name_to_index.get(motor_name, 0)
|
||||
|
||||
# Convert velocity to rad/s
|
||||
velocity_rad_per_sec = np.deg2rad(velocity_deg_per_sec)
|
||||
|
||||
# Compute friction torque
|
||||
friction_torque = compute_friction_torque(velocity_rad_per_sec, motor_index)
|
||||
friction_torques_nm[motor_full_name] = friction_torque
|
||||
|
||||
# Apply friction compensation to right arm (all joints INCLUDING gripper)
|
||||
for motor in follower.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = friction_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get motor index for damping gain
|
||||
motor_index = motor_name_to_index.get(motor, 0)
|
||||
kd = DAMPING_KD[motor_index]
|
||||
|
||||
# Send MIT control command with friction compensation + damping
|
||||
follower.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Apply friction compensation to left arm (all joints INCLUDING gripper)
|
||||
for motor in follower.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = friction_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get motor index for damping gain
|
||||
motor_index = motor_name_to_index.get(motor, 0)
|
||||
kd = DAMPING_KD[motor_index]
|
||||
|
||||
# Send MIT control command with friction compensation + damping
|
||||
follower.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print status every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping friction compensation...")
|
||||
|
||||
finally:
|
||||
print("\nDisabling all motors and disconnecting...")
|
||||
follower.bus_right.disable_torque()
|
||||
follower.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
follower.disconnect()
|
||||
print("✓ Safe shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Executable
+142
@@ -0,0 +1,142 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import pinocchio as pin
|
||||
from os.path import join, dirname, exists, expanduser
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0,
|
||||
)
|
||||
|
||||
|
||||
print("Initializing robot...")
|
||||
follower = OpenArmsFollower(config)
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Load URDF for Pinocchio dynamics
|
||||
urdf_path = "/home/croissant/Documents/openarm_description/openarm_bimanual_pybullet.urdf"
|
||||
|
||||
pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, dirname(urdf_path))
|
||||
pin_robot.data = pin_robot.model.createData()
|
||||
print(f"✓ Loaded Pinocchio model with {pin_robot.nq} DoFs")
|
||||
|
||||
follower.pin_robot = pin_robot
|
||||
|
||||
print(f"Applying gravity compensation")
|
||||
print(" 1. Support the arm before starting")
|
||||
print(" 2. The arm will be held in place by gravity compensation")
|
||||
print(" 3. You should be able to move it with gentle force")
|
||||
print("\nPress ENTER when ready to start...")
|
||||
input()
|
||||
|
||||
print(f"✓ Motors enabled")
|
||||
print("\nStarting gravity compensation loop...")
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get current joint positions from robot
|
||||
obs = follower.get_observation()
|
||||
|
||||
# Extract positions in degrees
|
||||
positions_deg = {}
|
||||
for motor in follower.bus_right.motors:
|
||||
key = f"right_{motor}.pos"
|
||||
if key in obs:
|
||||
positions_deg[f"right_{motor}"] = obs[key]
|
||||
|
||||
for motor in follower.bus_left.motors:
|
||||
key = f"left_{motor}.pos"
|
||||
if key in obs:
|
||||
positions_deg[f"left_{motor}"] = obs[key]
|
||||
|
||||
# Convert to radians and calculate gravity torques
|
||||
# Use the built-in method from OpenArmsFollower
|
||||
positions_rad = {k: np.deg2rad(v) for k, v in positions_deg.items()}
|
||||
torques_nm = follower._gravity_from_q(positions_rad)
|
||||
|
||||
# Apply gravity compensation to right arm (all joints except gripper)
|
||||
for motor in follower.bus_right.motors:
|
||||
if motor == "gripper":
|
||||
continue # Skip gripper
|
||||
|
||||
full_name = f"right_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Send MIT control command with gravity compensation torque
|
||||
follower.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=0.0, # No velocity damping
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Apply gravity compensation to left arm (all joints except gripper)
|
||||
for motor in follower.bus_left.motors:
|
||||
if motor == "gripper":
|
||||
continue # Skip gripper
|
||||
|
||||
full_name = f"left_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Send MIT control command with gravity compensation torque
|
||||
follower.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=0.0, # No velocity damping
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print status every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping gravity compensation...")
|
||||
|
||||
finally:
|
||||
print("\nDisabling all motors and disconnecting...")
|
||||
follower.bus_right.disable_torque()
|
||||
follower.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
follower.disconnect()
|
||||
print("✓ Safe shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
OpenArms Dataset Recording with Gravity + Friction Compensation
|
||||
|
||||
Records a dataset using OpenArms follower robot with leader teleoperator.
|
||||
Leader arms have gravity and friction compensation for weightless, easy movement.
|
||||
Includes 3 cameras: left wrist, right wrist, and base camera.
|
||||
|
||||
Uses the same compensation approach as teleop_with_compensation.py
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
# Recording parameters
|
||||
NUM_EPISODES = 1
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 600
|
||||
RESET_TIME_SEC = 120
|
||||
TASK_DESCRIPTION = "OpenArms task description"
|
||||
|
||||
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||
FRICTION_SCALE = 1.0
|
||||
|
||||
def record_loop_with_compensation(
|
||||
robot,
|
||||
leader,
|
||||
events,
|
||||
fps,
|
||||
dataset,
|
||||
dataset_features,
|
||||
control_time_s,
|
||||
single_task,
|
||||
display_data=True,
|
||||
):
|
||||
"""
|
||||
Custom record loop that applies gravity + friction compensation to leader.
|
||||
Based on record_loop but with integrated compensation.
|
||||
"""
|
||||
dt = 1 / fps
|
||||
episode_start_time = time.perf_counter()
|
||||
|
||||
# All joints (both arms)
|
||||
all_joints = []
|
||||
for motor in leader.bus_right.motors:
|
||||
all_joints.append(f"right_{motor}")
|
||||
for motor in leader.bus_left.motors:
|
||||
all_joints.append(f"left_{motor}")
|
||||
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
elapsed = loop_start - episode_start_time
|
||||
|
||||
# Check if we should exit
|
||||
if elapsed >= control_time_s or events["exit_early"] or events["stop_recording"]:
|
||||
break
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities in degrees
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity torques for leader using built-in method
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
# Calculate friction torques for leader using built-in method
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=FRICTION_SCALE
|
||||
)
|
||||
|
||||
# Combine gravity + friction torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower (both arms)
|
||||
follower_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
# Send action to robot
|
||||
if follower_action:
|
||||
robot.send_action(follower_action)
|
||||
|
||||
# Get observation from robot (includes camera images)
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Add to dataset if we have a dataset
|
||||
if dataset is not None:
|
||||
# Build properly formatted observation frame
|
||||
obs_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
|
||||
|
||||
# Build properly formatted action frame (keep .pos suffix - it matches the feature names)
|
||||
action_frame = build_dataset_frame(dataset_features, follower_action, prefix="action")
|
||||
|
||||
# Combine into single frame
|
||||
frame = {**obs_frame, **action_frame}
|
||||
|
||||
# Add metadata (task is required, timestamp will be auto-calculated by add_frame)
|
||||
frame["task"] = single_task
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Display data if requested
|
||||
if display_data:
|
||||
log_rerun_data(observation=observation, action=follower_action)
|
||||
|
||||
# Maintain loop rate
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
sleep_time = dt - loop_duration
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main recording loop with gravity compensation."""
|
||||
|
||||
print("=" * 70)
|
||||
print("OpenArms Dataset Recording with Compensation")
|
||||
print("=" * 70)
|
||||
|
||||
# Create camera configurations (3 cameras: left wrist, right wrist, base)
|
||||
# Using actual device paths found by lerobot-find-cameras opencv
|
||||
camera_config = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=640, height=480, fps=FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video7", width=640, height=480, fps=FPS),
|
||||
}
|
||||
|
||||
# Configure follower robot with cameras
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2",
|
||||
port_right="can3",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=camera_config,
|
||||
)
|
||||
|
||||
# Configure leader teleoperator (no cameras needed)
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
# Initialize robot and teleoperator
|
||||
print("\nInitializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
# Connect devices
|
||||
print("Connecting and calibrating...")
|
||||
follower.connect(calibrate=True)
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
# Verify URDF is loaded for gravity compensation
|
||||
if leader.pin_robot is None:
|
||||
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||
|
||||
# Configure the dataset features
|
||||
# For actions, we only want to record positions (not velocity or torque)
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
action_features = hw_to_dataset_features(action_features_hw, "action")
|
||||
obs_features = hw_to_dataset_features(follower.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
print("\nCreating dataset...")
|
||||
repo_id = "<hf_username>/<dataset_repo_id>" # TODO: Replace with your Hugging Face repo
|
||||
|
||||
# Check if dataset already exists and prompt user
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||
while dataset_path.exists():
|
||||
print(f"\nDataset already exists at: {dataset_path}")
|
||||
print("\nOptions:")
|
||||
print(" 1. Overwrite existing dataset")
|
||||
print(" 2. Use a different name")
|
||||
print(" 3. Abort")
|
||||
|
||||
choice = input("\nEnter your choice (1/2/3): ").strip()
|
||||
|
||||
if choice == '1':
|
||||
print(f"Removing existing dataset...")
|
||||
shutil.rmtree(dataset_path)
|
||||
print("✓ Existing dataset removed")
|
||||
break
|
||||
elif choice == '2':
|
||||
print("\nCurrent repo_id:", repo_id)
|
||||
new_repo_id = input("Enter new repo_id (format: <username>/<dataset_name>): ").strip()
|
||||
if new_repo_id and '/' in new_repo_id:
|
||||
repo_id = new_repo_id
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||
print(f"✓ Using new repo_id: {repo_id}")
|
||||
# Loop will continue if this new path also exists
|
||||
else:
|
||||
print("Invalid repo_id format. Please use format: <username>/<dataset_name>")
|
||||
elif choice == '3':
|
||||
print("Aborting. Please remove the existing dataset manually or restart with a different repo_id.")
|
||||
follower.disconnect()
|
||||
leader.disconnect()
|
||||
return
|
||||
else:
|
||||
print("Invalid choice. Please enter 1, 2, or 3.")
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize keyboard listener and visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_recording")
|
||||
|
||||
# Enable motors on both leader arms for gravity compensation
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(f"Recording {NUM_EPISODES} episodes")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print("=" * 70)
|
||||
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||
print("\nKeyboard controls:")
|
||||
print(" - Press 'q' to stop recording")
|
||||
print(" - Press 'r' to re-record current episode")
|
||||
print("=" * 70)
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Record episode with compensation active
|
||||
record_loop_with_compensation(
|
||||
robot=follower,
|
||||
leader=leader,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
dataset_features=dataset_features,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop_with_compensation(
|
||||
robot=follower,
|
||||
leader=leader,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=None, # Don't save reset period
|
||||
dataset_features=dataset_features,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Handle re-recording
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Only save episode if frames were recorded
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer["size"] > 0:
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
else:
|
||||
log_say("No frames recorded, skipping episode save")
|
||||
# Clear the empty buffer
|
||||
dataset.episode_buffer = None
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping recording...")
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
try:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
print("✓ Shutdown complete")
|
||||
except Exception as e:
|
||||
print(f"Shutdown error: {e}")
|
||||
|
||||
# Upload dataset
|
||||
print("\nUploading dataset to Hugging Face Hub...")
|
||||
try:
|
||||
dataset.push_to_hub()
|
||||
print("✓ Dataset uploaded successfully")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to upload dataset: {e}")
|
||||
print("You can manually upload later using: dataset.push_to_hub()")
|
||||
|
||||
print("✓ Recording complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Dataset Replay Example
|
||||
|
||||
Replays position actions from a recorded dataset on an OpenArms follower robot.
|
||||
Only position commands (ending with .pos) are replayed, not velocity or torque.
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/replay.py
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
# Configuration
|
||||
EPISODE_IDX = 0
|
||||
DATASET_REPO_ID = "lerobot-data-collection/replay-this-2025-11-02-17-58" # TODO: Replace with your dataset
|
||||
DATASET_ROOT = None # Use default cache location, or specify custom path
|
||||
|
||||
# Robot configuration - adjust these to match your setup
|
||||
ROBOT_CONFIG = OpenArmsFollowerConfig(
|
||||
port_left="can2", # CAN interface for left arm
|
||||
port_right="can3", # CAN interface for right arm
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0, # Safety limit: max degrees to move per step
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main replay function."""
|
||||
print("=" * 70)
|
||||
print("OpenArms Dataset Replay")
|
||||
print("=" * 70)
|
||||
print(f"\nDataset: {DATASET_REPO_ID}")
|
||||
print(f"Episode: {EPISODE_IDX}")
|
||||
print(f"Robot: {ROBOT_CONFIG.id}")
|
||||
print(f" Left arm: {ROBOT_CONFIG.port_left}")
|
||||
print(f" Right arm: {ROBOT_CONFIG.port_right}")
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
# Initialize the robot
|
||||
print("\n[1/3] Initializing robot...")
|
||||
robot = OpenArmsFollower(ROBOT_CONFIG)
|
||||
|
||||
# Load the dataset
|
||||
print(f"\n[2/3] Loading dataset '{DATASET_REPO_ID}'...")
|
||||
dataset = LeRobotDataset(
|
||||
DATASET_REPO_ID,
|
||||
root=DATASET_ROOT,
|
||||
episodes=[EPISODE_IDX]
|
||||
)
|
||||
|
||||
# Filter dataset to only include frames from the specified episode
|
||||
# (required for dataset V3.0 where episodes are chunked)
|
||||
episode_frames = dataset.hf_dataset.filter(
|
||||
lambda x: x["episode_index"] == EPISODE_IDX
|
||||
)
|
||||
|
||||
if len(episode_frames) == 0:
|
||||
raise ValueError(
|
||||
f"No frames found for episode {EPISODE_IDX} in dataset {DATASET_REPO_ID}"
|
||||
)
|
||||
|
||||
print(f" Found {len(episode_frames)} frames in episode {EPISODE_IDX}")
|
||||
|
||||
# Extract action features from dataset
|
||||
action_features = dataset.features.get(ACTION, {})
|
||||
action_names = action_features.get("names", [])
|
||||
|
||||
# Filter to only position actions (ending with .pos)
|
||||
position_action_names = [name for name in action_names if name.endswith(".pos")]
|
||||
|
||||
if not position_action_names:
|
||||
raise ValueError(
|
||||
f"No position actions found in dataset. Action names: {action_names}"
|
||||
)
|
||||
|
||||
print(f" Found {len(position_action_names)} position actions to replay")
|
||||
print(f" Actions: {', '.join(position_action_names[:5])}{'...' if len(position_action_names) > 5 else ''}")
|
||||
|
||||
# Select only action columns from dataset
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
print(f"\n[3/3] Connecting to robot...")
|
||||
robot.connect(calibrate=False) # Skip calibration for replay
|
||||
|
||||
if not robot.is_connected:
|
||||
raise RuntimeError("Robot failed to connect!")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Ready to replay!")
|
||||
print("=" * 70)
|
||||
print("\nThe robot will replay the recorded positions.")
|
||||
print("Press Ctrl+C to stop at any time.\n")
|
||||
|
||||
input("Press ENTER to start replaying...")
|
||||
|
||||
# Replay loop
|
||||
log_say(f"Replaying episode {EPISODE_IDX}", blocking=True)
|
||||
|
||||
try:
|
||||
for idx in range(len(episode_frames)):
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Extract action array from dataset
|
||||
action_array = actions[idx][ACTION]
|
||||
|
||||
# Build action dictionary, but only include position actions
|
||||
action = {}
|
||||
for i, name in enumerate(action_names):
|
||||
# Only include position actions (ending with .pos)
|
||||
if name.endswith(".pos"):
|
||||
action[name] = float(action_array[i])
|
||||
|
||||
# Send action to robot
|
||||
robot.send_action(action)
|
||||
|
||||
# Maintain replay rate (use dataset fps)
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
dt_s = 1.0 / dataset.fps - loop_duration
|
||||
busy_wait(dt_s)
|
||||
|
||||
# Progress indicator every 100 frames
|
||||
if (idx + 1) % 100 == 0:
|
||||
progress = (idx + 1) / len(episode_frames) * 100
|
||||
print(f"Progress: {idx + 1}/{len(episode_frames)} frames ({progress:.1f}%)")
|
||||
|
||||
print(f"\n✓ Successfully replayed {len(episode_frames)} frames")
|
||||
log_say("Replay complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nReplay interrupted by user")
|
||||
finally:
|
||||
# Disconnect robot
|
||||
print("\nDisconnecting robot...")
|
||||
robot.disconnect()
|
||||
print("✓ Replay complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Executable
+73
@@ -0,0 +1,73 @@
|
||||
#!/bin/bash
|
||||
# Setup all OpenArms CAN interfaces with CAN FD
|
||||
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo "OpenArms CAN FD Interface Setup"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Mode: CAN FD"
|
||||
echo " - Nominal bitrate: 1 Mbps"
|
||||
echo " - Data bitrate: 5 Mbps"
|
||||
echo ""
|
||||
echo "Configuring interfaces can0, can1, can2, can3..."
|
||||
echo ""
|
||||
|
||||
# Configure each CAN interface with CAN FD
|
||||
for i in 0 1 2 3; do
|
||||
interface="can$i"
|
||||
|
||||
# Check if interface exists
|
||||
if ! ip link show "$interface" &> /dev/null; then
|
||||
echo "⚠ $interface: Not found, skipping"
|
||||
continue
|
||||
fi
|
||||
|
||||
# Bring down interface
|
||||
sudo ip link set "$interface" down 2>/dev/null
|
||||
|
||||
# Configure CAN FD mode
|
||||
sudo ip link set "$interface" type can \
|
||||
bitrate 1000000 \
|
||||
dbitrate 5000000 \
|
||||
fd on
|
||||
|
||||
# Bring up interface
|
||||
sudo ip link set "$interface" up
|
||||
|
||||
# Verify configuration
|
||||
if ip link show "$interface" | grep -q "UP"; then
|
||||
echo "✓ $interface: Configured and UP"
|
||||
else
|
||||
echo "✗ $interface: Failed to bring UP"
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verification"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Show detailed status for each interface
|
||||
for i in 0 1 2 3; do
|
||||
interface="can$i"
|
||||
if ip link show "$interface" &> /dev/null; then
|
||||
echo "$interface:"
|
||||
# Show key parameters
|
||||
ip -d link show "$interface" | grep -E "can|state|bitrate|dbitrate" | head -3
|
||||
echo ""
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=========================================="
|
||||
echo "Setup Complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "All interfaces configured for CAN FD mode"
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo " 1. Test motors: python debug_can_communication.py"
|
||||
echo " 2. Run teleoperation: python examples/openarms/teleop.py"
|
||||
echo ""
|
||||
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
OpenArms Teleoperation Example - Full Dual Arms
|
||||
|
||||
This script demonstrates teleoperation of OpenArms follower robot using an OpenArms leader arm.
|
||||
It first calibrates both devices, then enters a teleoperation loop for both arms.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2", # CAN interface for follower left arm
|
||||
port_right="can3", # CAN interface for follower right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0, # Safety limit
|
||||
)
|
||||
|
||||
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0", # CAN interface for leader left arm
|
||||
port_right="can1", # CAN interface for leader right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_leader",
|
||||
manual_control=True, # Enable manual control (torque disabled)
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
print("OpenArms Teleoperation - Full Dual Arms")
|
||||
print("=" * 60)
|
||||
|
||||
# Initialize devices
|
||||
print("\n[1/4] Initializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
# Connect and calibrate follower
|
||||
print("\n[2/4] Connecting and calibrating follower robot...")
|
||||
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Connect and calibrate leader
|
||||
print("\n[3/4] Connecting and calibrating leader arm...")
|
||||
print("Note: The leader arm will have torque disabled for manual control.")
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
# Wait for user to be ready
|
||||
print("\n[4/4] Ready for teleoperation!")
|
||||
print("\nBoth arms will be controlled (16 motors total):")
|
||||
print(" RIGHT ARM: joints 1-7 + gripper")
|
||||
print(" LEFT ARM: joints 1-7 + gripper")
|
||||
|
||||
print("\nPress ENTER to start teleoperation...")
|
||||
input()
|
||||
|
||||
print("\nTeleoperation started! Move both leader arms.")
|
||||
print("Press Ctrl+C to stop.\n")
|
||||
|
||||
# All joints for both arms (16 motors total)
|
||||
all_joints = [
|
||||
# Right arm
|
||||
"right_joint_1",
|
||||
"right_joint_2",
|
||||
"right_joint_3",
|
||||
"right_joint_4",
|
||||
"right_joint_5",
|
||||
"right_joint_6",
|
||||
"right_joint_7",
|
||||
"right_gripper",
|
||||
# Left arm
|
||||
"left_joint_1",
|
||||
"left_joint_2",
|
||||
"left_joint_3",
|
||||
"left_joint_4",
|
||||
"left_joint_5",
|
||||
"left_joint_6",
|
||||
"left_joint_7",
|
||||
"left_gripper",
|
||||
]
|
||||
|
||||
# Performance monitoring
|
||||
loop_times = []
|
||||
start_time = time.perf_counter()
|
||||
last_print_time = start_time
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get action from leader
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Filter to only position data for all joints (both arms)
|
||||
joint_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
joint_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
# Send action to follower (both arms)
|
||||
if joint_action:
|
||||
follower.send_action(joint_action)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print stats every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
min_time = min(loop_times)
|
||||
max_time = max(loop_times)
|
||||
max_hz = 1.0 / min_time if min_time > 0 else 0
|
||||
min_hz = 1.0 / max_time if max_time > 0 else 0
|
||||
|
||||
print(f"[Hz Stats] Avg: {current_hz:.1f} Hz | "
|
||||
f"Range: {min_hz:.1f}-{max_hz:.1f} Hz | "
|
||||
f"Avg loop time: {avg_time*1000:.1f} ms")
|
||||
|
||||
# Reset for next measurement window
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping teleoperation...")
|
||||
finally:
|
||||
# Disconnect devices
|
||||
print("Disconnecting devices...")
|
||||
try:
|
||||
follower.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting follower: {e}")
|
||||
|
||||
try:
|
||||
leader.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting leader: {e}")
|
||||
|
||||
print("Done!")
|
||||
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
OpenArms Mini Teleoperation Example
|
||||
|
||||
This script demonstrates teleoperation of an OpenArms follower robot using
|
||||
an OpenArms Mini leader (Feetech-based) with dual arms (16 motors total).
|
||||
|
||||
The OpenArms Mini has:
|
||||
- Right arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||
- Left arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||
|
||||
Note on gripper normalization:
|
||||
- OpenArms Mini gripper: 0-100 scale (0=closed, 100=open)
|
||||
- OpenArms follower gripper: degrees (0=closed, -65=open)
|
||||
- This script automatically converts between the two ranges
|
||||
"""
|
||||
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.teleoperators.openarms_mini.openarms_mini import OpenArmsMini
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
|
||||
# Target control frequency
|
||||
TARGET_FPS = 30
|
||||
|
||||
# Configure the OpenArms follower (Damiao motors on CAN bus)
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can0", # CAN interface for follower left arm
|
||||
port_right="can1", # CAN interface for follower right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0, # Safety limit (degrees per step)
|
||||
)
|
||||
|
||||
# Configure the OpenArms Mini leader (Feetech motors on serial)
|
||||
leader_config = OpenArmsMiniConfig(
|
||||
port_right="/dev/ttyACM0", # Serial port for right arm
|
||||
port_left="/dev/ttyACM1", # Serial port for left arm
|
||||
id="openarms_mini",
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
print("OpenArms Mini → OpenArms Follower Teleoperation")
|
||||
|
||||
# Initialize devices
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsMini(leader_config)
|
||||
|
||||
# Connect and calibrate follower
|
||||
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Connect and calibrate leader
|
||||
print("Note: The leader arms will have torque disabled for manual control.")
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
print("\nPress ENTER to start teleoperation...")
|
||||
input()
|
||||
|
||||
print("Press Ctrl+C to stop.\n")
|
||||
|
||||
# All joints for both arms (16 motors total)
|
||||
all_joints = [
|
||||
# Right arm
|
||||
"right_joint_1",
|
||||
"right_joint_2",
|
||||
"right_joint_3",
|
||||
"right_joint_4",
|
||||
"right_joint_5",
|
||||
"right_joint_6",
|
||||
"right_joint_7",
|
||||
"right_gripper",
|
||||
# Left arm
|
||||
"left_joint_1",
|
||||
"left_joint_2",
|
||||
"left_joint_3",
|
||||
"left_joint_4",
|
||||
"left_joint_5",
|
||||
"left_joint_6",
|
||||
"left_joint_7",
|
||||
"left_gripper",
|
||||
]
|
||||
|
||||
# Performance monitoring
|
||||
loop_times = []
|
||||
avg_loop_time = 0.0
|
||||
min_loop_time = float('inf')
|
||||
max_loop_time = 0.0
|
||||
stats_update_interval = 1.0 # Update stats every 1 second
|
||||
last_stats_update = time.perf_counter()
|
||||
|
||||
|
||||
SWAPPED_JOINTS = {
|
||||
"right_joint_6": "right_joint_7",
|
||||
"right_joint_7": "right_joint_6",
|
||||
"left_joint_6": "left_joint_7",
|
||||
"left_joint_7": "left_joint_6",
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get actions and observations
|
||||
leader_action = leader.get_action()
|
||||
follower_obs = follower.get_observation()
|
||||
|
||||
joint_action = {}
|
||||
for joint in all_joints:
|
||||
leader_key = f"{joint}.pos"
|
||||
|
||||
# Determine which follower joint this leader joint controls
|
||||
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||
follower_key = f"{follower_joint}.pos"
|
||||
|
||||
# Get leader position (default 0 if missing)
|
||||
pos = leader_action.get(leader_key, 0.0)
|
||||
|
||||
# Convert gripper values: Mini uses 0-100, OpenArms uses 0 to -65 degrees
|
||||
if "gripper" in joint:
|
||||
# Map 0-100 (Mini) to 0 to -65 (OpenArms)
|
||||
# 0 (closed) -> 0°, 100 (open) -> -65°
|
||||
pos = (pos / 100.0) * -65.0
|
||||
|
||||
# Store in action dict for follower
|
||||
joint_action[follower_key] = pos
|
||||
|
||||
follower.send_action(joint_action)
|
||||
|
||||
# Loop timing
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Update stats periodically
|
||||
current_time = time.perf_counter()
|
||||
if current_time - last_stats_update >= stats_update_interval:
|
||||
if loop_times:
|
||||
avg_loop_time = sum(loop_times) / len(loop_times)
|
||||
min_loop_time = min(loop_times)
|
||||
max_loop_time = max(loop_times)
|
||||
loop_times = []
|
||||
last_stats_update = current_time
|
||||
|
||||
# Display everything
|
||||
sys.stdout.write("\033[H\033[J") # Clear screen
|
||||
|
||||
# Show timing stats at the top
|
||||
if avg_loop_time > 0:
|
||||
avg_hz = 1.0 / avg_loop_time
|
||||
min_hz = 1.0 / max_loop_time if max_loop_time > 0 else 0
|
||||
max_hz = 1.0 / min_loop_time if min_loop_time > 0 and min_loop_time < float('inf') else 0
|
||||
print(f"[Performance] Target: {TARGET_FPS} Hz | Avg: {avg_hz:.1f} Hz | Range: {min_hz:.1f}-{max_hz:.1f} Hz | Loop: {avg_loop_time*1000:.1f} ms\n")
|
||||
else:
|
||||
print(f"[Performance] Target: {TARGET_FPS} Hz | Measuring...\n")
|
||||
|
||||
# Show joint positions
|
||||
print(f"{'Joint':<20} {'Leader':>15} {'Follower':>15}")
|
||||
print(f"{'':20} {'(0-100/deg)':>15} {'(deg)':>15}")
|
||||
print("-" * 52)
|
||||
|
||||
for joint in all_joints:
|
||||
leader_key = f"{joint}.pos"
|
||||
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||
follower_key = f"{follower_joint}.pos"
|
||||
|
||||
leader_pos = leader_action.get(leader_key, 0.0)
|
||||
follower_pos = follower_obs.get(follower_key, 0.0)
|
||||
|
||||
print(f"{joint:<20} {leader_pos:>15.2f} {follower_pos:>15.2f}")
|
||||
|
||||
# Smart sleep to maintain target FPS
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
busy_wait(max(0, 1.0 / TARGET_FPS - dt_s))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping teleoperation...")
|
||||
finally:
|
||||
# Disconnect devices
|
||||
print("Disconnecting devices...")
|
||||
try:
|
||||
follower.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting follower: {e}")
|
||||
|
||||
try:
|
||||
leader.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting leader: {e}")
|
||||
|
||||
print("Done!")
|
||||
|
||||
+202
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
OpenArms Teleoperation with Gravity + Friction Compensation
|
||||
|
||||
Leader arms (both LEFT and RIGHT): Gravity + Friction compensation (weightless, easy to move)
|
||||
Follower arms (both LEFT and RIGHT): Mirror leader movements
|
||||
|
||||
Uses the URDF file from the lerobot repository.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
|
||||
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||
FRICTION_SCALE = 1.0
|
||||
|
||||
def main():
|
||||
"""Main teleoperation loop with gravity compensation"""
|
||||
|
||||
print("=" * 70)
|
||||
print("OpenArms Teleoperation with Gravity Compensation")
|
||||
print("=" * 70)
|
||||
|
||||
# Configuration
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2",
|
||||
port_right="can3",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
)
|
||||
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
# Initialize and connect
|
||||
print("\nInitializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
# URDF is automatically loaded in the leader constructor
|
||||
if leader.pin_robot is None:
|
||||
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||
|
||||
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||
print("Press ENTER to start...")
|
||||
input()
|
||||
|
||||
# Enable motors on both leader arms for gravity compensation
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
# Main control loop
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
# All joints (both arms)
|
||||
all_joints = []
|
||||
for motor in leader.bus_right.motors:
|
||||
all_joints.append(f"right_{motor}")
|
||||
for motor in leader.bus_left.motors:
|
||||
all_joints.append(f"left_{motor}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities in degrees
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity torques for leader using built-in method
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
# Calculate friction torques for leader using built-in method
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=FRICTION_SCALE
|
||||
)
|
||||
|
||||
# Combine gravity + friction torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower (both arms)
|
||||
follower_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
if follower_action:
|
||||
follower.send_action(follower_action)
|
||||
|
||||
# Performance monitoring
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping...")
|
||||
finally:
|
||||
try:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
print("✓ Shutdown complete")
|
||||
except Exception as e:
|
||||
print(f"Shutdown error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unify all tasks in a dataset to a single task (modifies in-place).
|
||||
|
||||
This script:
|
||||
1. Loads a dataset
|
||||
2. Sets all task_index to 0 and task description to "fold"
|
||||
3. Updates tasks.parquet and task_index in data files (in-place, no copying)
|
||||
|
||||
Usage:
|
||||
python examples/openarms/unify_task.py --repo-id lerobot-data-collection/level1_rac1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
write_info,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
|
||||
# Single unified task
|
||||
UNIFIED_TASK = "fold"
|
||||
|
||||
|
||||
def unify_dataset_tasks(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
push_to_hub: bool = False,
|
||||
) -> None:
|
||||
"""Unify all tasks in a dataset to a single task (modifies in-place).
|
||||
|
||||
Args:
|
||||
repo_id: Dataset repository ID.
|
||||
root: Optional root path for dataset.
|
||||
push_to_hub: Whether to push the result to HuggingFace Hub.
|
||||
"""
|
||||
input_root = root if root else HF_LEROBOT_HOME / repo_id
|
||||
input_repo_id = repo_id
|
||||
|
||||
logging.info(f"Loading metadata from {repo_id}")
|
||||
|
||||
# Load source metadata
|
||||
src_meta = LeRobotDatasetMetadata(repo_id, root=input_root)
|
||||
|
||||
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
||||
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
||||
|
||||
# Modify in-place (input_root == output_root supported)
|
||||
data_dir = input_root / DATA_DIR
|
||||
|
||||
# Process data files - set all task_index to 0
|
||||
logging.info("Processing data files (in-place)...")
|
||||
for parquet_file in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Processing data"):
|
||||
df = pd.read_parquet(parquet_file)
|
||||
df["task_index"] = 0 # All tasks unified to index 0
|
||||
df.to_parquet(parquet_file)
|
||||
|
||||
# Process episodes metadata - set all tasks to unified task
|
||||
logging.info("Processing episodes metadata (in-place)...")
|
||||
episodes_dir = input_root / "meta" / "episodes"
|
||||
if episodes_dir.exists():
|
||||
for parquet_file in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Processing episodes"):
|
||||
df = pd.read_parquet(parquet_file)
|
||||
df["tasks"] = [[UNIFIED_TASK]] * len(df) # All episodes get the unified task
|
||||
df.to_parquet(parquet_file)
|
||||
else:
|
||||
logging.warning(f"No episodes directory found at {episodes_dir}, skipping")
|
||||
|
||||
# Update tasks.parquet with single task
|
||||
logging.info(f"Creating single task: {UNIFIED_TASK}")
|
||||
new_tasks = pd.DataFrame({"task_index": [0]}, index=[UNIFIED_TASK])
|
||||
write_tasks(new_tasks, input_root)
|
||||
|
||||
# Update info.json
|
||||
new_info = src_meta.info.copy()
|
||||
new_info["total_tasks"] = 1
|
||||
write_info(new_info, input_root)
|
||||
|
||||
logging.info(f"Dataset modified in-place at {input_root}")
|
||||
logging.info(f"Task: {UNIFIED_TASK}")
|
||||
|
||||
if push_to_hub:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
logging.info(f"Pushing {input_repo_id} to hub")
|
||||
dataset = LeRobotDataset(input_repo_id, root=input_root)
|
||||
dataset.push_to_hub(private=True)
|
||||
logging.info("Push complete!")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Unify all tasks in a dataset to a single task 'fold' (modifies in-place)."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Dataset repository ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Optional root path (defaults to HF_LEROBOT_HOME/repo_id)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push result to HuggingFace Hub",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
unify_dataset_tasks(
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
push_to_hub=args.push_to_hub,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,745 @@
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
|
||||
main {
|
||||
min-height: 100vh;
|
||||
padding: 2rem;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 2rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
h2 {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin: 0 0 1rem 0;
|
||||
}
|
||||
|
||||
h3 {
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
color: #666;
|
||||
margin: 0 0 0.5rem 0;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1920px;
|
||||
margin: 0 auto;
|
||||
display: grid;
|
||||
grid-template-columns: minmax(500px, 600px) 1fr;
|
||||
gap: 2rem;
|
||||
align-items: start;
|
||||
}
|
||||
|
||||
/* Left column container */
|
||||
.left-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
/* Right column container */
|
||||
.right-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
/* Responsive: Stack on smaller screens */
|
||||
@media (max-width: 1200px) {
|
||||
.container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.panel {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.config-panel {
|
||||
border: 2px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.config-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.config-header:hover {
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.toggle-icon {
|
||||
font-size: 1rem;
|
||||
color: #6b7280;
|
||||
transition: transform 0.2s;
|
||||
}
|
||||
|
||||
.config-content {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.robot-setup {
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.robot-status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 1rem;
|
||||
border-radius: 6px;
|
||||
font-weight: 500;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.robot-status.ready {
|
||||
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||
color: #065f46;
|
||||
border: 1px solid #10b981;
|
||||
}
|
||||
|
||||
.robot-status.not-ready {
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
color: #92400e;
|
||||
border: 1px solid #f59e0b;
|
||||
}
|
||||
|
||||
.btn-setup {
|
||||
background: #10b981;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-setup:hover:not(:disabled) {
|
||||
background: #059669;
|
||||
}
|
||||
|
||||
.btn-setup:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-zero {
|
||||
background: #8b5cf6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-zero:hover:not(:disabled) {
|
||||
background: #7c3aed;
|
||||
}
|
||||
|
||||
.btn-zero:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.zero-position-section {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.btn-zero-large {
|
||||
width: 100%;
|
||||
background: #8b5cf6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
box-shadow: 0 2px 4px rgba(139, 92, 246, 0.2);
|
||||
}
|
||||
|
||||
.btn-zero-large:hover:not(:disabled) {
|
||||
background: #7c3aed;
|
||||
box-shadow: 0 4px 8px rgba(139, 92, 246, 0.3);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.btn-zero-large:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.delete-episode-section {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.btn-delete {
|
||||
width: 100%;
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
box-shadow: 0 2px 4px rgba(239, 68, 68, 0.2);
|
||||
}
|
||||
|
||||
.btn-delete:hover:not(:disabled) {
|
||||
background: #dc2626;
|
||||
box-shadow: 0 4px 8px rgba(239, 68, 68, 0.3);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.btn-delete:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.delete-info {
|
||||
margin-top: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
color: #666;
|
||||
text-align: center;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.btn-disconnect {
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-disconnect:hover {
|
||||
background: #dc2626;
|
||||
}
|
||||
|
||||
.btn-refresh {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.4rem 0.8rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-refresh:hover:not(:disabled) {
|
||||
background: #2563eb;
|
||||
}
|
||||
|
||||
.btn-refresh:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.control-panel {
|
||||
border: 2px solid #10b981;
|
||||
}
|
||||
|
||||
.status-banner {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 1rem;
|
||||
padding: 1rem 1.5rem;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 1.5rem;
|
||||
font-weight: 500;
|
||||
font-size: 0.95rem;
|
||||
}
|
||||
|
||||
.status-banner.initializing {
|
||||
background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%);
|
||||
color: #1e40af;
|
||||
border-left: 4px solid #3b82f6;
|
||||
}
|
||||
|
||||
.status-banner.encoding {
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
color: #92400e;
|
||||
border-left: 4px solid #f59e0b;
|
||||
}
|
||||
|
||||
.status-banner.uploading {
|
||||
background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
|
||||
color: #3730a3;
|
||||
border-left: 4px solid #6366f1;
|
||||
}
|
||||
|
||||
.status-banner.success {
|
||||
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||
color: #065f46;
|
||||
border-left: 4px solid #10b981;
|
||||
}
|
||||
|
||||
.status-banner.warning {
|
||||
background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%);
|
||||
color: #991b1b;
|
||||
border-left: 4px solid #ef4444;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border: 3px solid rgba(0, 0, 0, 0.1);
|
||||
border-top-color: currentColor;
|
||||
border-radius: 50%;
|
||||
animation: spin 0.8s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.control-horizontal {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.control-left {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.control-right {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
input[type="text"] {
|
||||
flex: 1;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
input[type="text"]:disabled {
|
||||
background: #f5f5f5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
input[type="text"]:focus {
|
||||
outline: none;
|
||||
border-color: #10b981;
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 0.75rem 1.5rem;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
font-size: 1rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.btn-set-task {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
.btn-set-task:hover:not(:disabled) {
|
||||
background: #2563eb;
|
||||
}
|
||||
|
||||
.btn-set-task:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-start {
|
||||
background: #10b981;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-start:hover:not(:disabled) {
|
||||
background: #059669;
|
||||
}
|
||||
|
||||
.btn-start:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-stop {
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-stop:hover {
|
||||
background: #dc2626;
|
||||
}
|
||||
|
||||
.btn-reset {
|
||||
padding: 0.5rem 1rem;
|
||||
background: #6b7280;
|
||||
color: white;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.btn-reset:hover {
|
||||
background: #4b5563;
|
||||
}
|
||||
|
||||
.status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 1rem;
|
||||
border-radius: 4px;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.status.recording {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
}
|
||||
|
||||
.status.recording.recording-active {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
background: #dc2626;
|
||||
color: white;
|
||||
padding: 1.5rem;
|
||||
border: 4px solid #991b1b;
|
||||
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.4);
|
||||
font-weight: 700;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.status.recording.recording-active .indicator {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
background: #fef2f2;
|
||||
animation: pulse-strong 1s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-strong {
|
||||
0%, 100% {
|
||||
opacity: 1;
|
||||
transform: scale(1);
|
||||
}
|
||||
50% {
|
||||
opacity: 0.7;
|
||||
transform: scale(1.1);
|
||||
}
|
||||
}
|
||||
|
||||
.status.recording.recording-active .time-display {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
font-size: 1.5rem;
|
||||
font-weight: 700;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.fps-display {
|
||||
font-size: 1rem;
|
||||
font-weight: 500;
|
||||
opacity: 0.95;
|
||||
}
|
||||
|
||||
.fps-warning {
|
||||
color: #fef2f2;
|
||||
animation: pulse-warning 1s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-warning {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
.status.recording.recording-active .btn-stop {
|
||||
align-self: stretch;
|
||||
}
|
||||
|
||||
.ramp-up-countdown {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.countdown-box {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 2rem 3rem;
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
border: 4px solid #f59e0b;
|
||||
border-radius: 16px;
|
||||
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||
min-width: 280px;
|
||||
animation: pulse-warm 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-warm {
|
||||
0%, 100% {
|
||||
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||
}
|
||||
50% {
|
||||
box-shadow: 0 6px 25px rgba(245, 158, 11, 0.6);
|
||||
}
|
||||
}
|
||||
|
||||
.countdown-label {
|
||||
font-size: 1rem;
|
||||
color: #92400e;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1.5px;
|
||||
font-weight: 800;
|
||||
margin-bottom: 1rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.countdown-value {
|
||||
font-size: 4.5rem;
|
||||
font-weight: 900;
|
||||
color: #d97706;
|
||||
font-family: 'Courier New', monospace;
|
||||
line-height: 1;
|
||||
text-shadow: 2px 2px 6px rgba(0, 0, 0, 0.15);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.countdown-subtitle {
|
||||
font-size: 0.875rem;
|
||||
color: #78350f;
|
||||
font-weight: 600;
|
||||
font-style: italic;
|
||||
text-align: center;
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
.status.idle {
|
||||
background: #f3f4f6;
|
||||
color: #374151;
|
||||
}
|
||||
|
||||
.indicator {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
background: #ef4444;
|
||||
animation: pulse 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
.counter {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 1.5rem;
|
||||
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%);
|
||||
border-radius: 8px;
|
||||
border: 2px solid #e5e7eb;
|
||||
min-width: 200px;
|
||||
}
|
||||
|
||||
.counter-label {
|
||||
font-size: 0.75rem;
|
||||
color: #6b7280;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.counter-value {
|
||||
font-size: 3rem;
|
||||
font-weight: 700;
|
||||
color: #10b981;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.time-display {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 600;
|
||||
font-family: 'Courier New', monospace;
|
||||
}
|
||||
|
||||
.error-box {
|
||||
padding: 1rem;
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
border-radius: 4px;
|
||||
border-left: 4px solid #ef4444;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.config-section {
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.config-section:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.config-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
label {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
color: #374151;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
select {
|
||||
padding: 0.5rem;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
background: white;
|
||||
}
|
||||
|
||||
select:disabled {
|
||||
background: #f5f5f5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
/* Camera Layout */
|
||||
.camera-layout {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.camera-base {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.camera-wrist-container {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, 1fr);
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.camera-wrist {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.camera {
|
||||
border: 1px solid #e5e7eb;
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.camera h3 {
|
||||
padding: 0.75rem;
|
||||
background: #f9fafb;
|
||||
border-bottom: 1px solid #e5e7eb;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.camera img {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
display: block;
|
||||
background: #000;
|
||||
min-height: 300px;
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.camera-placeholder {
|
||||
text-align: center;
|
||||
padding: 4rem 2rem;
|
||||
background: #f9fafb;
|
||||
border-radius: 4px;
|
||||
border: 2px dashed #d1d5db;
|
||||
}
|
||||
|
||||
.camera-placeholder p {
|
||||
margin: 0.5rem 0;
|
||||
font-size: 1rem;
|
||||
color: #6b7280;
|
||||
}
|
||||
|
||||
.camera-placeholder p:first-child {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 500;
|
||||
color: #374151;
|
||||
}
|
||||
|
||||
.hint {
|
||||
margin-top: 0.5rem;
|
||||
font-size: 0.75rem;
|
||||
color: #6b7280;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,857 @@
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import './App.css';
|
||||
|
||||
const API_BASE = 'http://localhost:8000/api';
|
||||
|
||||
function App() {
|
||||
// State
|
||||
const [task, setTask] = useState('');
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [isInitializing, setIsInitializing] = useState(false);
|
||||
const [isEncoding, setIsEncoding] = useState(false);
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [robotsReady, setRobotsReady] = useState(false);
|
||||
const [elapsedTime, setElapsedTime] = useState(0);
|
||||
const [currentFps, setCurrentFps] = useState(0);
|
||||
const [loopFps, setLoopFps] = useState(0);
|
||||
const [episodeCount, setEpisodeCount] = useState(0);
|
||||
const [error, setError] = useState(null);
|
||||
const [statusMessage, setStatusMessage] = useState('Ready');
|
||||
const [uploadStatus, setUploadStatus] = useState(null);
|
||||
const [rampUpRemaining, setRampUpRemaining] = useState(0);
|
||||
const [movingToZero, setMovingToZero] = useState(false);
|
||||
const [configExpanded, setConfigExpanded] = useState(false);
|
||||
const [latestRepoId, setLatestRepoId] = useState(null);
|
||||
|
||||
// Configuration
|
||||
const [config, setConfig] = useState({
|
||||
leader_type: 'openarms', // 'openarms' or 'openarms_mini'
|
||||
leader_left: 'can0',
|
||||
leader_right: 'can1',
|
||||
follower_left: 'can2',
|
||||
follower_right: 'can3',
|
||||
left_wrist: '/dev/video0',
|
||||
right_wrist: '/dev/video1',
|
||||
base: '/dev/video4'
|
||||
});
|
||||
|
||||
// Available options
|
||||
const [availableCameras, setAvailableCameras] = useState([]);
|
||||
const [availableUsbPorts, setAvailableUsbPorts] = useState([]);
|
||||
const canInterfaces = ['can0', 'can1', 'can2', 'can3'];
|
||||
|
||||
const statusIntervalRef = useRef(null);
|
||||
const hasInitializedRef = useRef(false);
|
||||
|
||||
const loadConfig = () => {
|
||||
try {
|
||||
const saved = localStorage.getItem('openarms_config');
|
||||
if (saved) {
|
||||
const loadedConfig = JSON.parse(saved);
|
||||
setConfig(prev => ({ ...prev, ...loadedConfig }));
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Load config error:', e);
|
||||
}
|
||||
};
|
||||
|
||||
const saveConfig = (newConfig) => {
|
||||
try {
|
||||
localStorage.setItem('openarms_config', JSON.stringify(newConfig || config));
|
||||
} catch (e) {
|
||||
console.error('Save config error:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch status periodically
|
||||
const fetchStatus = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/status`);
|
||||
const data = await response.json();
|
||||
|
||||
setIsRecording(data.is_recording);
|
||||
setIsInitializing(data.is_initializing);
|
||||
setIsEncoding(data.is_encoding);
|
||||
setIsUploading(data.is_uploading);
|
||||
setRobotsReady(data.robots_ready);
|
||||
setElapsedTime(data.elapsed_time);
|
||||
setCurrentFps(data.current_fps || 0);
|
||||
setLoopFps(data.loop_fps || 0);
|
||||
setEpisodeCount(data.episode_count);
|
||||
setError(data.error);
|
||||
setStatusMessage(data.status_message || 'Ready');
|
||||
setUploadStatus(data.upload_status);
|
||||
setRampUpRemaining(data.ramp_up_remaining || 0);
|
||||
setMovingToZero(data.moving_to_zero || false);
|
||||
|
||||
// Track the latest repo_id from the backend
|
||||
if (data.latest_repo_id) {
|
||||
setLatestRepoId(data.latest_repo_id);
|
||||
}
|
||||
|
||||
if (data.config) {
|
||||
// Only merge server config if we don't have a saved config (first load)
|
||||
if (!localStorage.getItem('openarms_config')) {
|
||||
setConfig(prev => {
|
||||
const merged = { ...data.config, ...prev };
|
||||
localStorage.setItem('openarms_config', JSON.stringify(merged));
|
||||
return merged;
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to fetch status:', e);
|
||||
}
|
||||
};
|
||||
|
||||
const setupRobots = async () => {
|
||||
// Show warning to verify camera positions
|
||||
const confirmed = window.confirm(
|
||||
'⚠️ IMPORTANT: Before connecting robots, please verify:\n\n' +
|
||||
'📹 Check that cameras are correctly positioned:\n' +
|
||||
' • LEFT wrist camera is actually on the LEFT arm\n' +
|
||||
' • RIGHT wrist camera is actually on the RIGHT arm\n' +
|
||||
' • BASE camera is actually the BASE/overhead camera\n\n' +
|
||||
'Incorrect camera positioning will result in invalid training data!\n\n' +
|
||||
'Click OK to continue with robot setup, or Cancel to review configuration.'
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
return; // User cancelled, don't proceed
|
||||
}
|
||||
|
||||
setError(null);
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/robots/setup`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(config)
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to setup robots');
|
||||
}
|
||||
|
||||
await response.json();
|
||||
saveConfig(config);
|
||||
} catch (e) {
|
||||
setError(`Robot setup failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Disconnect robots
|
||||
const disconnectRobots = async () => {
|
||||
try {
|
||||
await fetch(`${API_BASE}/robots/disconnect`, { method: 'POST' });
|
||||
setRobotsReady(false);
|
||||
} catch (e) {
|
||||
console.error('Failed to disconnect robots:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Discover cameras
|
||||
const discoverCameras = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/cameras/discover`);
|
||||
const data = await response.json();
|
||||
const cameras = data.cameras || [];
|
||||
setAvailableCameras(cameras);
|
||||
|
||||
// Get list of valid camera IDs
|
||||
const validCameraIds = cameras.map(cam => String(cam.id));
|
||||
|
||||
// Auto-fix config if current values are invalid or not set
|
||||
const updated = { ...config };
|
||||
let changed = false;
|
||||
|
||||
// Auto-fix invalid camera config
|
||||
if (!config.left_wrist || !validCameraIds.includes(config.left_wrist)) {
|
||||
if (cameras.length >= 1) {
|
||||
updated.left_wrist = String(cameras[0].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.right_wrist || !validCameraIds.includes(config.right_wrist)) {
|
||||
if (cameras.length >= 2) {
|
||||
updated.right_wrist = String(cameras[1].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.base || !validCameraIds.includes(config.base)) {
|
||||
if (cameras.length >= 3) {
|
||||
updated.base = String(cameras[2].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
}
|
||||
|
||||
if (cameras.length === 0) {
|
||||
setError('No cameras detected! Please connect cameras and refresh.');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to discover cameras:', e);
|
||||
setError(`Camera discovery failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Discover USB ports
|
||||
const discoverUsbPorts = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/usb/discover`);
|
||||
const data = await response.json();
|
||||
const ports = data.ports || [];
|
||||
setAvailableUsbPorts(ports);
|
||||
|
||||
// Auto-fix config if OpenArms Mini is selected and ports are invalid
|
||||
if (config.leader_type === 'openarms_mini') {
|
||||
const updated = { ...config };
|
||||
let changed = false;
|
||||
|
||||
if (ports.length >= 1 && !ports.includes(config.leader_left)) {
|
||||
updated.leader_left = ports[0];
|
||||
changed = true;
|
||||
}
|
||||
|
||||
if (ports.length >= 2 && !ports.includes(config.leader_right)) {
|
||||
updated.leader_right = ports[1];
|
||||
changed = true;
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
}
|
||||
}
|
||||
|
||||
if (ports.length === 0) {
|
||||
console.warn('No USB ports detected for OpenArms Mini');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to discover USB ports:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Set task only (for pedal use)
|
||||
const setTaskOnly = async () => {
|
||||
if (!task.trim()) {
|
||||
setError('Please enter a task description');
|
||||
return;
|
||||
}
|
||||
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/set-task`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ task, ...config })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to set task');
|
||||
}
|
||||
|
||||
const result = await response.json();
|
||||
setStatusMessage(result.message || `Task set: ${task}`);
|
||||
saveConfig(config);
|
||||
|
||||
// Clear success message after 3 seconds
|
||||
setTimeout(() => {
|
||||
if (!isRecording && !isInitializing) {
|
||||
setStatusMessage('Ready');
|
||||
}
|
||||
}, 3000);
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
// Start recording
|
||||
const startRecording = async () => {
|
||||
if (!task.trim()) {
|
||||
setError('Please enter a task description');
|
||||
return;
|
||||
}
|
||||
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/start`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ task, ...config })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to start recording');
|
||||
}
|
||||
|
||||
await response.json();
|
||||
saveConfig(config);
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
// Stop recording
|
||||
const stopRecording = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/stop`, {
|
||||
method: 'POST'
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to stop recording');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setError(null);
|
||||
// Update latest repo_id after recording
|
||||
if (data.dataset_name) {
|
||||
setLatestRepoId(`lerobot-data-collection/${data.dataset_name}`);
|
||||
}
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
const deleteLatestEpisode = async () => {
|
||||
if (!latestRepoId) {
|
||||
setError('No episode to delete');
|
||||
return;
|
||||
}
|
||||
|
||||
const confirmed = window.confirm(
|
||||
`WARNING: This will permanently delete the repository:\n\n${latestRepoId}\n\nThis action cannot be undone. Continue?`
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/delete-latest`, { method: 'POST' });
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to delete episode');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setLatestRepoId(null);
|
||||
setEpisodeCount(Math.max(0, episodeCount - 1));
|
||||
setStatusMessage(`Deleted: ${data.deleted_repo}`);
|
||||
|
||||
setTimeout(() => {
|
||||
if (!isRecording && !isInitializing) {
|
||||
setStatusMessage('Ready');
|
||||
}
|
||||
}, 3000);
|
||||
} catch (e) {
|
||||
setError(`Delete failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Reset counter
|
||||
const resetCounter = async () => {
|
||||
try {
|
||||
await fetch(`${API_BASE}/counter/reset`, { method: 'POST' });
|
||||
setEpisodeCount(0);
|
||||
} catch (e) {
|
||||
console.error('Failed to reset counter:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Move robot to zero position
|
||||
const moveToZero = async () => {
|
||||
setError(null);
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/robots/move-to-zero`, { method: 'POST' });
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to move to zero position');
|
||||
}
|
||||
await response.json();
|
||||
} catch (e) {
|
||||
setError(`Move to zero failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Format time as MM:SS
|
||||
const formatTime = (seconds) => {
|
||||
const mins = Math.floor(seconds / 60);
|
||||
const secs = Math.floor(seconds % 60);
|
||||
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`;
|
||||
};
|
||||
|
||||
// Update config and save
|
||||
const updateConfig = (key, value) => {
|
||||
const updated = { ...config, [key]: value };
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
};
|
||||
|
||||
// Initialize on mount only
|
||||
useEffect(() => {
|
||||
// Prevent double-initialization in development
|
||||
if (hasInitializedRef.current) {
|
||||
return;
|
||||
}
|
||||
hasInitializedRef.current = true;
|
||||
|
||||
loadConfig();
|
||||
discoverCameras();
|
||||
discoverUsbPorts();
|
||||
fetchStatus();
|
||||
statusIntervalRef.current = setInterval(fetchStatus, 1000);
|
||||
|
||||
return () => {
|
||||
if (statusIntervalRef.current) {
|
||||
clearInterval(statusIntervalRef.current);
|
||||
}
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []); // Run only once on mount
|
||||
|
||||
// Discover USB ports when leader type changes to Mini
|
||||
useEffect(() => {
|
||||
if (config.leader_type === 'openarms_mini') {
|
||||
discoverUsbPorts();
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [config.leader_type]);
|
||||
|
||||
return (
|
||||
<main>
|
||||
<header>
|
||||
<h1>OpenArms Recording</h1>
|
||||
</header>
|
||||
|
||||
<div className="container">
|
||||
{/* Left Column: Configuration and Recording Control */}
|
||||
<div className="left-column">
|
||||
{/* Configuration Panel */}
|
||||
<section className="panel config-panel">
|
||||
<div
|
||||
className="config-header"
|
||||
onClick={() => setConfigExpanded(!configExpanded)}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onKeyDown={(e) => e.key === 'Enter' && setConfigExpanded(!configExpanded)}
|
||||
>
|
||||
<h2>⚙️ Configuration</h2>
|
||||
<span className="toggle-icon">{configExpanded ? '▼' : '▶'}</span>
|
||||
</div>
|
||||
|
||||
{configExpanded && (
|
||||
<div className="config-content">
|
||||
{/* Robot Setup */}
|
||||
<div className="config-section">
|
||||
<h3>🤖 Robot Setup</h3>
|
||||
<div className="robot-setup">
|
||||
{robotsReady ? (
|
||||
<div className="robot-status ready">
|
||||
<span>✅ Robots Ready - Recording will start instantly</span>
|
||||
<button onClick={disconnectRobots} className="btn-disconnect">
|
||||
Disconnect Robots
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="robot-status not-ready">
|
||||
<span>⚠️ Robots not initialized - Recording will take ~10 seconds</span>
|
||||
<button
|
||||
onClick={setupRobots}
|
||||
disabled={isRecording || isInitializing}
|
||||
className="btn-setup"
|
||||
>
|
||||
🚀 Setup Robots
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Leader Type Selection */}
|
||||
<div className="config-section">
|
||||
<h3>🎮 Leader Type</h3>
|
||||
<div className="config-grid">
|
||||
<label style={{gridColumn: '1 / -1'}}>
|
||||
Leader Arm Type
|
||||
<select
|
||||
value={config.leader_type}
|
||||
onChange={(e) => updateConfig('leader_type', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
<option value="openarms">OpenArms (CAN Bus - Damiao Motors)</option>
|
||||
<option value="openarms_mini">OpenArms Mini (USB - Feetech Motors)</option>
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Leader Interfaces (CAN or USB based on type) */}
|
||||
<div className="config-section">
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||
<h3>
|
||||
{config.leader_type === 'openarms_mini'
|
||||
? `Leader Ports (USB/Serial) ${availableUsbPorts.length > 0 ? `(${availableUsbPorts.length} detected)` : ''}`
|
||||
: 'Leader Interfaces (CAN)'}
|
||||
</h3>
|
||||
{config.leader_type === 'openarms_mini' && (
|
||||
<button
|
||||
onClick={discoverUsbPorts}
|
||||
className="btn-refresh"
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
🔄 Refresh
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Leader Left
|
||||
<select
|
||||
value={config.leader_left}
|
||||
onChange={(e) => updateConfig('leader_left', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{config.leader_type === 'openarms_mini' ? (
|
||||
availableUsbPorts.length > 0 ? (
|
||||
availableUsbPorts.map((port) => (
|
||||
<option key={port} value={port}>{port}</option>
|
||||
))
|
||||
) : (
|
||||
<option value="">No USB ports detected</option>
|
||||
)
|
||||
) : (
|
||||
canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))
|
||||
)}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Leader Right
|
||||
<select
|
||||
value={config.leader_right}
|
||||
onChange={(e) => updateConfig('leader_right', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{config.leader_type === 'openarms_mini' ? (
|
||||
availableUsbPorts.length > 0 ? (
|
||||
availableUsbPorts.map((port) => (
|
||||
<option key={port} value={port}>{port}</option>
|
||||
))
|
||||
) : (
|
||||
<option value="">No USB ports detected</option>
|
||||
)
|
||||
) : (
|
||||
canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))
|
||||
)}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Follower CAN Interfaces */}
|
||||
<div className="config-section">
|
||||
<h3>Follower Interfaces (CAN)</h3>
|
||||
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Follower Left
|
||||
<select
|
||||
value={config.follower_left}
|
||||
onChange={(e) => updateConfig('follower_left', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Follower Right
|
||||
<select
|
||||
value={config.follower_right}
|
||||
onChange={(e) => updateConfig('follower_right', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Camera Configuration */}
|
||||
<div className="config-section">
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||
<h3>Cameras {availableCameras.length > 0 && `(${availableCameras.length} detected)`}</h3>
|
||||
<button
|
||||
onClick={discoverCameras}
|
||||
className="btn-refresh"
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
🔄 Refresh
|
||||
</button>
|
||||
</div>
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Left Wrist
|
||||
<select
|
||||
value={config.left_wrist}
|
||||
onChange={(e) => updateConfig('left_wrist', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Right Wrist
|
||||
<select
|
||||
value={config.right_wrist}
|
||||
onChange={(e) => updateConfig('right_wrist', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Base Camera
|
||||
<select
|
||||
value={config.base}
|
||||
onChange={(e) => updateConfig('base', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
|
||||
{/* Control Panel */}
|
||||
<section className="panel control-panel">
|
||||
<h2>🎬 Recording Control</h2>
|
||||
|
||||
{/* Status Banner - Always show important statuses */}
|
||||
{isInitializing && (
|
||||
<div className="status-banner initializing">
|
||||
<div className="spinner"></div>
|
||||
<span>{statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isEncoding && (
|
||||
<div className="status-banner encoding">
|
||||
<div className="spinner"></div>
|
||||
<span>📹 {statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isUploading && (
|
||||
<div className="status-banner uploading">
|
||||
<div className="spinner"></div>
|
||||
<span>☁️ {statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{uploadStatus && !isRecording && !isEncoding && !isUploading && (
|
||||
<div className={`status-banner ${uploadStatus.startsWith('✓') ? 'success' : 'warning'}`}>
|
||||
<span>{uploadStatus}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="control-horizontal">
|
||||
{/* Task Input and Status */}
|
||||
<div className="control-left">
|
||||
<div className="input-group">
|
||||
<input
|
||||
type="text"
|
||||
value={task}
|
||||
onChange={(e) => setTask(e.target.value)}
|
||||
placeholder="Task description (e.g., 'pick and place')"
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading}
|
||||
onKeyPress={(e) => {
|
||||
if (e.key === 'Enter' && robotsReady) {
|
||||
setTaskOnly();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<button
|
||||
onClick={setTaskOnly}
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||
className="btn-set-task"
|
||||
title={!robotsReady ? 'Please setup robots first' : 'Store task for pedal use (Enter key)'}
|
||||
>
|
||||
💾 Set Task
|
||||
</button>
|
||||
<button
|
||||
onClick={startRecording}
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||
className="btn-start"
|
||||
title={!robotsReady ? 'Please setup robots first' : ''}
|
||||
>
|
||||
{isInitializing
|
||||
? '⏳ Initializing...'
|
||||
: isRecording
|
||||
? '⏺ Recording...'
|
||||
: robotsReady
|
||||
? '⏺ Start Recording'
|
||||
: '⏺ Setup Robots First'}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Ramp-up Countdown */}
|
||||
{isRecording && rampUpRemaining > 0 && (
|
||||
<div className="ramp-up-countdown">
|
||||
<div className="countdown-box">
|
||||
<div className="countdown-label">⚡ WARMING UP - PID RAMP-UP</div>
|
||||
<div className="countdown-value">{rampUpRemaining.toFixed(1)}s</div>
|
||||
<div className="countdown-subtitle">Recording will start automatically...</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Recording Status - Only show after ramp-up */}
|
||||
{isRecording && rampUpRemaining <= 0 && (
|
||||
<div className="status recording recording-active">
|
||||
<div className="indicator"></div>
|
||||
<div className="time-display">
|
||||
<span>{formatTime(elapsedTime)}</span>
|
||||
<span className="fps-display">
|
||||
Loop: {loopFps.toFixed(1)} Hz
|
||||
{loopFps > 0 && loopFps < 29 && <span className="fps-warning"> ⚠️</span>}
|
||||
</span>
|
||||
<span className="fps-display">Recording: {currentFps.toFixed(1)} FPS</span>
|
||||
</div>
|
||||
<button onClick={stopRecording} className="btn-stop">
|
||||
⏹ Stop
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Episode Counter */}
|
||||
<div className="control-right">
|
||||
<div className="counter">
|
||||
<div className="counter-label">Episodes Recorded</div>
|
||||
<div className="counter-value">{episodeCount}</div>
|
||||
<button onClick={resetCounter} className="btn-reset">
|
||||
Reset
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Delete Latest Episode Button */}
|
||||
{!isRecording && !isInitializing && latestRepoId && (
|
||||
<div className="delete-episode-section">
|
||||
<button
|
||||
onClick={deleteLatestEpisode}
|
||||
className="btn-delete"
|
||||
title="Delete the latest recorded episode from HuggingFace Hub"
|
||||
>
|
||||
Delete Latest Episode
|
||||
</button>
|
||||
<div className="delete-info">Will delete: {latestRepoId}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Move to Zero Button */}
|
||||
{robotsReady && !isRecording && !isInitializing && (
|
||||
<div className="zero-position-section">
|
||||
<button
|
||||
onClick={moveToZero}
|
||||
disabled={movingToZero}
|
||||
className="btn-zero-large"
|
||||
title="Move both leader and follower robots to zero position (2s)"
|
||||
>
|
||||
{movingToZero ? '⏳ Moving to Zero Position...' : '🎯 Move to Zero Position (Leader + Follower)'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error Display */}
|
||||
{error && (
|
||||
<div className="error-box">
|
||||
⚠️ {error}
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
|
||||
{/* Right Column: Camera Feeds */}
|
||||
<div className="right-column">
|
||||
<section className="panel cameras">
|
||||
<h2>📹 Camera Views</h2>
|
||||
{robotsReady || isRecording || isInitializing ? (
|
||||
<div className="camera-layout">
|
||||
{/* Base camera - full width */}
|
||||
<div className="camera camera-base">
|
||||
<h3>Base Camera</h3>
|
||||
<img src={`${API_BASE}/camera/stream/base`} alt="Base Camera" />
|
||||
</div>
|
||||
|
||||
{/* Wrist cameras - side by side */}
|
||||
<div className="camera-wrist-container">
|
||||
<div className="camera camera-wrist">
|
||||
<h3>Left Wrist</h3>
|
||||
<img src={`${API_BASE}/camera/stream/left_wrist`} alt="Left Wrist Camera" />
|
||||
</div>
|
||||
|
||||
<div className="camera camera-wrist">
|
||||
<h3>Right Wrist</h3>
|
||||
<img src={`${API_BASE}/camera/stream/right_wrist`} alt="Right Wrist Camera" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="camera-placeholder">
|
||||
<p>📷 Camera feeds will appear when robots are set up</p>
|
||||
<p className="hint">Click "Setup Robots" above to preview camera feeds</p>
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</main>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
# OpenArms Web Recording Interface
|
||||
|
||||
A web interface for recording OpenArms datasets.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
cd examples/openarms_web_interface
|
||||
npm install
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
**Start everything with one command:**
|
||||
|
||||
```bash
|
||||
./launch.sh
|
||||
```
|
||||
|
||||
This will:
|
||||
- Start the FastAPI backend on port 8000
|
||||
- Start the React frontend on port 5173
|
||||
- Show live logs from both services
|
||||
|
||||
Then open your browser to: **http://localhost:5173**
|
||||
|
||||
**Stop with:** `Ctrl+C`
|
||||
|
||||
---
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Configure CAN interfaces** and **camera paths** in the dropdowns
|
||||
2. Click **"Setup Robots"** to initialize (once at start)
|
||||
3. Enter a **task description**
|
||||
4. Click **"Start Recording"** to begin an episode
|
||||
5. Click **"Stop Recording"** when done
|
||||
6. Dataset is automatically encoded and uploaded to HuggingFace Hub as **private**
|
||||
7. Repeat steps 3-6 for more episodes (no need to re-setup robots!)
|
||||
|
||||
---
|
||||
@@ -0,0 +1,12 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>OpenArms Recording Interface</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
Executable
+142
@@ -0,0 +1,142 @@
|
||||
#!/bin/bash
|
||||
|
||||
# OpenArms Web Interface Launcher
|
||||
# Starts Rerun viewer, FastAPI backend, and React frontend
|
||||
|
||||
set -e
|
||||
|
||||
# Colors for output
|
||||
GREEN='\033[0;32m'
|
||||
BLUE='\033[0;34m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Get script directory
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
echo -e "${BLUE}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${BLUE}║ OpenArms Web Recording Interface ║${NC}"
|
||||
echo -e "${BLUE}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
|
||||
# Function to cleanup on exit
|
||||
cleanup() {
|
||||
echo ""
|
||||
echo -e "${YELLOW}Shutting down services...${NC}"
|
||||
|
||||
# Kill all child processes
|
||||
pkill -P $$ 2>/dev/null || true
|
||||
|
||||
# Kill specific services by port
|
||||
lsof -ti:8000 | xargs kill -9 2>/dev/null || true # Backend
|
||||
lsof -ti:5173 | xargs kill -9 2>/dev/null || true # Frontend
|
||||
lsof -ti:9876 | xargs kill -9 2>/dev/null || true # Rerun (if spawned)
|
||||
|
||||
echo -e "${GREEN}✓ Services stopped${NC}"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Register cleanup on script exit
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
# Check if required commands exist
|
||||
command -v rerun >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'rerun' not found. Please install: pip install rerun-sdk${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
command -v python >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'python' not found${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
command -v npm >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'npm' not found${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Check if node_modules exists
|
||||
if [ ! -d "node_modules" ]; then
|
||||
echo -e "${YELLOW}⚠ node_modules not found. Running npm install...${NC}"
|
||||
npm install
|
||||
echo -e "${GREEN}✓ Dependencies installed${NC}"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}Starting services...${NC}"
|
||||
echo ""
|
||||
|
||||
# 1. Start FastAPI backend (Rerun will start when recording begins)
|
||||
echo -e "${BLUE}[1/2]${NC} Starting FastAPI backend on port 8000..."
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Use Python from current environment (if lerobot env is active, it will use that)
|
||||
# Otherwise, check if we need to use conda run
|
||||
if [[ "$CONDA_DEFAULT_ENV" == "lerobot" ]]; then
|
||||
# Already in lerobot environment
|
||||
echo -e "${GREEN}✓ Using active lerobot environment${NC}"
|
||||
PYTHON_CMD="python"
|
||||
elif command -v conda >/dev/null 2>&1 && conda env list | grep -q "^lerobot "; then
|
||||
# lerobot env exists but not active - use conda run
|
||||
echo -e "${YELLOW}Using conda run with lerobot environment...${NC}"
|
||||
PYTHON_CMD="conda run -n lerobot --no-capture-output python"
|
||||
else
|
||||
# Fall back to system python
|
||||
echo -e "${YELLOW}⚠ Warning: lerobot environment not found, using system python${NC}"
|
||||
PYTHON_CMD="python"
|
||||
fi
|
||||
|
||||
$PYTHON_CMD web_record_server.py > /tmp/openarms_backend.log 2>&1 &
|
||||
BACKEND_PID=$!
|
||||
sleep 3
|
||||
|
||||
if ps -p $BACKEND_PID > /dev/null; then
|
||||
echo -e "${GREEN}✓ Backend started${NC} (PID: $BACKEND_PID)"
|
||||
echo -e " URL: ${BLUE}http://localhost:8000${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ Failed to start backend${NC}"
|
||||
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_backend.log${NC}"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# 2. Start React frontend
|
||||
echo -e "${BLUE}[2/2]${NC} Starting React frontend on port 5173..."
|
||||
cd "$SCRIPT_DIR"
|
||||
npm run dev > /tmp/openarms_frontend.log 2>&1 &
|
||||
FRONTEND_PID=$!
|
||||
sleep 3
|
||||
|
||||
if ps -p $FRONTEND_PID > /dev/null; then
|
||||
echo -e "${GREEN}✓ Frontend started${NC} (PID: $FRONTEND_PID)"
|
||||
echo -e " URL: ${BLUE}http://localhost:5173${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ Failed to start frontend${NC}"
|
||||
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_frontend.log${NC}"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Display status
|
||||
echo -e "${GREEN}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${GREEN}║ All services running! 🚀 ║${NC}"
|
||||
echo -e "${GREEN}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
echo -e "🔧 ${BLUE}Backend:${NC} http://localhost:8000"
|
||||
echo -e "🌐 ${BLUE}Frontend:${NC} http://localhost:5173"
|
||||
echo -e "📊 ${BLUE}Rerun:${NC} Will spawn automatically when recording starts"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Open your browser to:${NC} ${BLUE}http://localhost:5173${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Logs:${NC}"
|
||||
echo -e " • Backend: tail -f /tmp/openarms_backend.log"
|
||||
echo -e " • Frontend: tail -f /tmp/openarms_frontend.log"
|
||||
echo ""
|
||||
echo -e "${RED}Press Ctrl+C to stop all services${NC}"
|
||||
echo ""
|
||||
|
||||
# Keep script running and wait for any service to exit
|
||||
wait
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
import { createRoot } from 'react-dom/client'
|
||||
import App from './App.jsx'
|
||||
|
||||
createRoot(document.getElementById('root')).render(
|
||||
<App />
|
||||
)
|
||||
|
||||
+1955
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"name": "openarms-web-interface",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^18.3.12",
|
||||
"@types/react-dom": "^18.3.1",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"vite": "^6.0.1"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
port: 5173,
|
||||
strictPort: false,
|
||||
host: true,
|
||||
open: false
|
||||
},
|
||||
build: {
|
||||
outDir: 'dist',
|
||||
sourcemap: true
|
||||
}
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
python examples/rac/rac_data_collection_openarms_rtc.py --robot.type=openarms_follower --robot.port_right=can1 --robot.port_left=can0 --robot.cameras="{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}" --teleop.type=openarms_mini --teleop.port_right=/dev/ttyACM0 --teleop.port_left=/dev/ttyACM1 --policy.path=lerobot-data-collection/level1_rac3_100k --dataset.repo_id=lerobot-data-collection/level1_rac3_rtc_s5_2 --dataset.single_task="Fold the T-shirt properly" --dataset.num_episodes=5
|
||||
@@ -0,0 +1,638 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
RaC (Recovery and Correction) Data Collection with Policy Rollout + Human Intervention.
|
||||
|
||||
This implements the RaC paradigm from "RaC: Robot Learning for Long-Horizon Tasks
|
||||
by Scaling Recovery and Correction" (Hu et al., 2025) for LeRobot.
|
||||
|
||||
RaC improves upon standard data collection (BC) and prior human-in-the-loop methods
|
||||
(DAgger, HG-DAgger) by explicitly collecting recovery and correction behaviors:
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously
|
||||
2. Press SPACE to pause - robot holds position
|
||||
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Key RaC Rules:
|
||||
- Rule 1 (Recover then Correct): Every intervention = recovery + correction (both human)
|
||||
- Rule 2 (Terminate after Intervention): Episode ends after correction
|
||||
|
||||
The recovery segment (teleoperating back to good state) is recorded as training data -
|
||||
this teaches the policy how to recover from errors.
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (robot holds position, no recording)
|
||||
c - Take control (start correction, recording resumes)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/rac_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/rac_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
IdentityProcessor,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
reset_time_s: float = 30
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCConfig:
|
||||
robot: RobotConfig
|
||||
dataset: RaCDatasetConfig
|
||||
policy: PreTrainedConfig
|
||||
teleop: TeleoperatorConfig
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def init_rac_keyboard_listener():
|
||||
"""Initialize keyboard listener with RaC-specific controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False, # SPACE pressed - policy paused, teleop tracking robot
|
||||
"correction_active": False, # 'c' pressed - human controlling, recording correction
|
||||
"in_reset": False, # True during reset period
|
||||
"start_next_episode": False, # Signal to start next episode
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logging.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
# During reset: any action key starts next episode
|
||||
if key == keyboard.Key.space or key == keyboard.Key.right:
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
# During episode
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position")
|
||||
print(" Press 'c' or START to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("[RaC] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("[RaC] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
start_pedal_listener(events)
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
def start_pedal_listener(events: dict):
|
||||
"""Start foot pedal listener thread if evdev is available."""
|
||||
import threading
|
||||
|
||||
try:
|
||||
from evdev import InputDevice, ecodes
|
||||
except ImportError:
|
||||
logging.info("[Pedal] evdev not installed - pedal support disabled")
|
||||
return
|
||||
|
||||
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
KEY_LEFT = "KEY_A" # Left pedal
|
||||
KEY_RIGHT = "KEY_C" # Right pedal
|
||||
|
||||
def pedal_reader():
|
||||
try:
|
||||
dev = InputDevice(PEDAL_DEVICE)
|
||||
print(f"[Pedal] Connected: {dev.name}")
|
||||
print(f"[Pedal] Right=pause/next, Left=take control/start")
|
||||
|
||||
for ev in dev.read_loop():
|
||||
if ev.type != ecodes.EV_KEY:
|
||||
continue
|
||||
|
||||
from evdev import categorize
|
||||
key = categorize(ev)
|
||||
code = key.keycode
|
||||
if isinstance(code, (list, tuple)):
|
||||
code = code[0]
|
||||
|
||||
# Only trigger on key down
|
||||
if key.keystate != 1:
|
||||
continue
|
||||
|
||||
if events["in_reset"]:
|
||||
# During reset: either pedal starts next episode
|
||||
if code in [KEY_LEFT, KEY_RIGHT]:
|
||||
print("\n[Pedal] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
# During episode
|
||||
if code == KEY_RIGHT:
|
||||
# Right pedal: SPACE (pause) when running, → (next) when in correction
|
||||
if events["correction_active"]:
|
||||
print("\n[Pedal] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif not events["policy_paused"]:
|
||||
print("\n[Pedal] ⏸ PAUSED - Policy stopped, teleop moving to robot")
|
||||
print(" Press left pedal to take control")
|
||||
events["policy_paused"] = True
|
||||
|
||||
elif code == KEY_LEFT:
|
||||
# Left pedal: START (take control) when paused
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[Pedal] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
|
||||
except FileNotFoundError:
|
||||
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
|
||||
except PermissionError:
|
||||
logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}")
|
||||
except Exception as e:
|
||||
logging.debug(f"[Pedal] Error: {e}")
|
||||
|
||||
thread = threading.Thread(target=pedal_reader, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for RaC recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, robot_proc, obs_proc
|
||||
|
||||
|
||||
def move_robot_to_zero(robot: Robot, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move all robot joints to zero position."""
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
target_pos = {k: 0.0 for k in current_pos}
|
||||
|
||||
print(f"[RaC] Moving robot to zero position ({duration_s}s)...")
|
||||
steps = int(duration_s * fps)
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos}
|
||||
robot.send_action(interp_pos)
|
||||
time.sleep(1 / fps)
|
||||
print("[RaC] Robot at zero position.")
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rac_rollout_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
fps: int,
|
||||
control_time_s: float,
|
||||
single_task: str,
|
||||
display_data: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
RaC rollout loop with two-stage intervention:
|
||||
|
||||
1. Policy runs autonomously (recording)
|
||||
2. SPACE: Policy pauses (NOT recording) - robot holds position
|
||||
3. 'c': Human takes control (recording correction)
|
||||
4. →: End episode
|
||||
"""
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
device = get_safe_torch_device(policy.config.device)
|
||||
frame_buffer = []
|
||||
|
||||
stats = {
|
||||
"total_frames": 0,
|
||||
"autonomous_frames": 0,
|
||||
"paused_frames": 0,
|
||||
"correction_frames": 0,
|
||||
}
|
||||
|
||||
last_robot_action = None
|
||||
was_paused = False
|
||||
was_correction_active = False
|
||||
waiting_for_takeover = False
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Detect transition to paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
print("[RaC] Moving teleop to robot position (2s smooth transition)...")
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
print("[RaC] Teleop aligned. Press START to take control.")
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
|
||||
# Wait for start button before enabling correction mode
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
print("[RaC] Start pressed - enabling teleop control...")
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
was_correction_active = True
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
|
||||
if events["correction_active"]:
|
||||
# Human controlling - record correction data
|
||||
robot_action = teleop.get_action()
|
||||
robot.send_action(robot_action)
|
||||
stats["correction_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
elif waiting_for_takeover:
|
||||
# Waiting for START - policy stopped, no recording, robot holds position
|
||||
if last_robot_action is not None:
|
||||
robot.send_action(last_robot_action)
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
elif events["policy_paused"]:
|
||||
# Paused and user acknowledged - hold last position, don't record
|
||||
if last_robot_action is not None:
|
||||
robot.send_action(last_robot_action)
|
||||
stats["paused_frames"] += 1
|
||||
robot_action = last_robot_action
|
||||
|
||||
else:
|
||||
# Normal policy execution - record
|
||||
action_values = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
robot_action: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
robot.send_action(robot_action)
|
||||
last_robot_action = robot_action
|
||||
stats["autonomous_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
if display_data and robot_action is not None:
|
||||
log_rerun_data(observation=obs, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def reset_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
events: dict,
|
||||
fps: int,
|
||||
):
|
||||
"""Reset period where human repositions environment. Two-stage: enable teleop, then start episode."""
|
||||
print("\n" + "=" * 65)
|
||||
print(" [RaC] RESET - Moving teleop to robot position...")
|
||||
print("=" * 65)
|
||||
|
||||
# Enter reset mode
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
# Move teleop to match robot position to avoid sudden jumps
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
# Stage 1: Wait for user to press start to enable teleoperation
|
||||
print(" Teleop aligned. Press any key/pedal to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
# Stage 2: Enable teleop and let user move robot to starting position
|
||||
events["start_next_episode"] = False
|
||||
teleop.disable_torque()
|
||||
print(" Teleop enabled - move robot to starting position")
|
||||
print(" Press any key/pedal to start next episode")
|
||||
|
||||
# Wait for user to signal ready for next episode
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
|
||||
# Exit reset mode and clear flags for next episode
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def rac_collect(cfg: RaCConfig) -> LeRobotDataset:
|
||||
"""Main RaC data collection function."""
|
||||
init_logging()
|
||||
logging.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="rac_collection")
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, robot_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot, "cameras") and robot.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
listener, events = init_rac_keyboard_listener()
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(" RaC (Recovery and Correction) Data Collection")
|
||||
print("=" * 65)
|
||||
print(" Policy runs autonomously until you intervene.")
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy (robot holds position, no recording)")
|
||||
print(" c - Take control (start correction, recording)")
|
||||
print(" → - End episode (save)")
|
||||
print(" ← - Re-record episode")
|
||||
print(" ESC - Stop session and push to hub")
|
||||
print("=" * 65 + "\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
move_robot_to_zero(robot, duration_s=2.0, fps=cfg.dataset.fps)
|
||||
|
||||
stats = rac_rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
logging.info(f"Episode stats: {stats}")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
# Reset between episodes
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
rac_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,659 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
RaC (Recovery and Correction) Data Collection for OpenArms Robot.
|
||||
|
||||
This implements the RaC paradigm from "RaC: Robot Learning for Long-Horizon Tasks
|
||||
by Scaling Recovery and Correction" (Hu et al., 2025) for LeRobot with OpenArms.
|
||||
|
||||
RaC improves upon standard data collection (BC) and prior human-in-the-loop methods
|
||||
(DAgger, HG-DAgger) by explicitly collecting recovery and correction behaviors:
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously (teleop is idle/free)
|
||||
2. Press SPACE to pause - teleop moves to match robot position
|
||||
3. Press 'c' to take control - teleop is free, human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Key RaC Rules:
|
||||
- Rule 1 (Recover then Correct): Every intervention = recovery + correction (both human)
|
||||
- Rule 2 (Terminate after Intervention): Episode ends after correction
|
||||
|
||||
The recovery segment (teleoperating back to good state) is recorded as training data -
|
||||
this teaches the policy how to recover from errors.
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (teleop mirrors robot, no recording)
|
||||
c - Take control (teleop free, recording correction)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/rac_data_collection_openarms.py \
|
||||
--robot.type=openarms_follower \
|
||||
--robot.port_right=can0 \
|
||||
--robot.port_left=can1 \
|
||||
--robot.cameras="{ left_wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=openarms_mini \
|
||||
--teleop.port_right=/dev/ttyUSB0 \
|
||||
--teleop.port_left=/dev/ttyUSB1 \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/rac_openarms_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
reset_time_s: float = 30
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCConfig:
|
||||
robot: RobotConfig
|
||||
dataset: RaCDatasetConfig
|
||||
teleop: TeleoperatorConfig
|
||||
policy: PreTrainedConfig | None = None
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def init_rac_keyboard_listener():
|
||||
"""Initialize keyboard listener with RaC-specific controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False, # SPACE pressed - policy paused, teleop tracking robot
|
||||
"correction_active": False, # 'c' pressed - human controlling, recording correction
|
||||
"in_reset": False, # True during reset period
|
||||
"start_next_episode": False, # Signal to start next episode
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logging.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
# During reset: any action key starts next episode
|
||||
if key == keyboard.Key.space or key == keyboard.Key.right:
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
# During episode
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position")
|
||||
print(" Press 'c' or START to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("[RaC] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("[RaC] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
start_pedal_listener(events)
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
def start_pedal_listener(events: dict):
|
||||
"""Start foot pedal listener thread if evdev is available."""
|
||||
import threading
|
||||
|
||||
try:
|
||||
from evdev import InputDevice, ecodes
|
||||
except ImportError:
|
||||
logging.info("[Pedal] evdev not installed - pedal support disabled")
|
||||
return
|
||||
|
||||
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
KEY_LEFT = "KEY_A" # Left pedal
|
||||
KEY_RIGHT = "KEY_C" # Right pedal
|
||||
|
||||
def pedal_reader():
|
||||
try:
|
||||
dev = InputDevice(PEDAL_DEVICE)
|
||||
print(f"[Pedal] Connected: {dev.name}")
|
||||
print(f"[Pedal] Right=pause/next, Left=take control/start")
|
||||
|
||||
for ev in dev.read_loop():
|
||||
if ev.type != ecodes.EV_KEY:
|
||||
continue
|
||||
|
||||
from evdev import categorize
|
||||
key = categorize(ev)
|
||||
code = key.keycode
|
||||
if isinstance(code, (list, tuple)):
|
||||
code = code[0]
|
||||
|
||||
# Only trigger on key down
|
||||
if key.keystate != 1:
|
||||
continue
|
||||
|
||||
if events["in_reset"]:
|
||||
# During reset: either pedal starts next episode
|
||||
if code in [KEY_LEFT, KEY_RIGHT]:
|
||||
print("\n[Pedal] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
# During episode
|
||||
if code == KEY_RIGHT:
|
||||
# Right pedal: SPACE (pause) when running, → (next) when in correction
|
||||
if events["correction_active"]:
|
||||
print("\n[Pedal] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif not events["policy_paused"]:
|
||||
print("\n[Pedal] ⏸ PAUSED - Policy stopped, teleop moving to robot")
|
||||
print(" Press left pedal to take control")
|
||||
events["policy_paused"] = True
|
||||
|
||||
elif code == KEY_LEFT:
|
||||
# Left pedal: START (take control) when paused
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[Pedal] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
|
||||
except FileNotFoundError:
|
||||
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
|
||||
except PermissionError:
|
||||
logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}")
|
||||
except Exception as e:
|
||||
logging.debug(f"[Pedal] Error: {e}")
|
||||
|
||||
thread = threading.Thread(target=pedal_reader, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for RaC recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, robot_proc, obs_proc
|
||||
|
||||
|
||||
def move_robot_to_zero(robot: Robot, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move all robot joints to zero position."""
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
target_pos = {k: 0.0 for k in current_pos}
|
||||
|
||||
print(f"[RaC] Moving robot to zero position ({duration_s}s)...")
|
||||
steps = int(duration_s * fps)
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos}
|
||||
robot.send_action(interp_pos)
|
||||
time.sleep(1 / fps)
|
||||
print("[RaC] Robot at zero position.")
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rac_rollout_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
fps: int,
|
||||
control_time_s: float,
|
||||
single_task: str,
|
||||
display_data: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
RaC rollout loop with two-stage intervention:
|
||||
|
||||
1. Policy runs autonomously (recording) - teleop free/idle
|
||||
2. SPACE: Policy pauses, teleop mirrors robot position (NOT recording)
|
||||
3. 'c': Human takes control, teleop torque disabled (recording correction)
|
||||
4. →: End episode
|
||||
|
||||
This allows smooth handoff - teleop tracks robot only when paused.
|
||||
"""
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
device = get_safe_torch_device(policy.config.device)
|
||||
frame_buffer = []
|
||||
|
||||
stats = {
|
||||
"total_frames": 0,
|
||||
"autonomous_frames": 0,
|
||||
"paused_frames": 0,
|
||||
"correction_frames": 0,
|
||||
}
|
||||
|
||||
# Start with teleop torque disabled - only enable when paused to track robot
|
||||
teleop.disable_torque()
|
||||
was_paused = False
|
||||
was_correction_active = False
|
||||
waiting_for_takeover = False
|
||||
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Detect transition to paused state - smooth move teleop to robot position
|
||||
if events["policy_paused"] and not was_paused:
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
print("[RaC] Moving teleop to robot position (2s smooth transition)...")
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
print("[RaC] Teleop aligned. Press START to take control.")
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
|
||||
# Wait for start button before enabling correction mode
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
print("[RaC] Start pressed - enabling teleop control...")
|
||||
teleop.disable_torque()
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
was_correction_active = True
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
if events["correction_active"]:
|
||||
# Human controlling - record correction data
|
||||
robot_action = teleop.get_action()
|
||||
# Convert gripper from teleop range (0-100) to robot degrees (-65 to 0)
|
||||
for key in robot_action:
|
||||
if "gripper" in key:
|
||||
robot_action[key] = -0.65 * robot_action[key]
|
||||
robot.send_action(robot_action)
|
||||
stats["correction_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
elif waiting_for_takeover:
|
||||
# Waiting for START - policy stopped, no recording, robot holds position
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
elif events["policy_paused"]:
|
||||
# Paused and user acknowledged - teleop tracks robot position, don't record
|
||||
robot_action = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
teleop.send_feedback(robot_action)
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
else:
|
||||
# Normal policy execution - record (teleop is free/idle)
|
||||
action_values = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
robot_action: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
robot.send_action(robot_action)
|
||||
stats["autonomous_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
# Ensure teleoperator torque is disabled at end
|
||||
teleop.disable_torque()
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def reset_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
events: dict,
|
||||
fps: int,
|
||||
):
|
||||
"""Reset period where human repositions environment. Two-stage: enable teleop, then start episode."""
|
||||
print("\n" + "=" * 65)
|
||||
print(" [RaC] RESET - Moving teleop to robot position...")
|
||||
print("=" * 65)
|
||||
|
||||
# Enter reset mode
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
# First move teleop to match robot position to avoid sudden jumps
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
# Stage 1: Wait for user to press start to enable teleoperation
|
||||
print(" Teleop aligned. Press any key/pedal to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
# Stage 2: Enable teleop and let user move robot to starting position
|
||||
events["start_next_episode"] = False
|
||||
teleop.disable_torque()
|
||||
print(" Teleop enabled - move robot to starting position")
|
||||
print(" Press any key/pedal to start next episode")
|
||||
|
||||
# Wait for user to signal ready for next episode
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
action = teleop.get_action()
|
||||
# Convert gripper from teleop range (0-100) to robot degrees (-65 to 0)
|
||||
for key in action:
|
||||
if "gripper" in key:
|
||||
action[key] = -0.65 * action[key]
|
||||
robot.send_action(action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
|
||||
# Exit reset mode and clear flags for next episode
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def rac_collect(cfg: RaCConfig) -> LeRobotDataset:
|
||||
"""Main RaC data collection function."""
|
||||
init_logging()
|
||||
logging.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="rac_collection_openarms")
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, robot_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot, "cameras") and robot.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
listener, events = init_rac_keyboard_listener()
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(" RaC (Recovery and Correction) Data Collection - OpenArms")
|
||||
print("=" * 65)
|
||||
print(" Policy runs autonomously until you intervene.")
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy (teleop tracks robot, no recording)")
|
||||
print(" c - Take control (start correction, recording)")
|
||||
print(" → - End episode (save)")
|
||||
print(" ← - Re-record episode")
|
||||
print(" ESC - Stop session and push to hub")
|
||||
print("=" * 65 + "\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
move_robot_to_zero(robot, duration_s=2.0, fps=cfg.dataset.fps)
|
||||
|
||||
stats = rac_rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
logging.info(f"Episode stats: {stats}")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
# Reset between episodes
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
rac_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,902 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
RaC (Recovery and Correction) Data Collection for OpenArms Robot with RTC.
|
||||
|
||||
This combines RaC data collection with Real-Time Chunking (RTC) for smooth policy execution.
|
||||
RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion
|
||||
despite high inference latency by asynchronously generating action chunks.
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously with RTC (teleop is idle/free)
|
||||
2. Press SPACE to pause - teleop moves to match robot position
|
||||
3. Press 'c' to take control - teleop is free, human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Usage:
|
||||
python examples/rac/rac_data_collection_openarms_rtc.py \
|
||||
--robot.port_right=can0 \
|
||||
--robot.port_left=can1 \
|
||||
--teleop.port_right=/dev/ttyUSB0 \
|
||||
--teleop.port_left=/dev/ttyUSB1 \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/rac_openarms_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class RaCRTCDatasetConfig:
|
||||
repo_id: str = "lerobot/rac_openarms_rtc"
|
||||
single_task: str = "default task"
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 500
|
||||
reset_time_s: float = 30
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
streaming_encoding: bool = True
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCRTCConfig:
|
||||
robot: RobotConfig = field(default_factory=lambda: OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
))
|
||||
teleop: TeleoperatorConfig = field(default_factory=lambda: OpenArmsMiniConfig(
|
||||
port_left="/dev/ttyUSB1",
|
||||
port_right="/dev/ttyUSB0",
|
||||
))
|
||||
dataset: RaCRTCDatasetConfig = field(default_factory=RaCRTCDatasetConfig)
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=20,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
))
|
||||
|
||||
interpolation: bool = True
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
device: str = "cuda"
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Torch compile is disabled by default for real-time inference
|
||||
# First inference with compile takes minutes to compile kernels
|
||||
use_torch_compile: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Thread-Safe Robot Wrapper (from evaluate_with_rtc.py)
|
||||
# ============================================================================
|
||||
|
||||
class RobotWrapper:
|
||||
"""Thread-safe wrapper for robot operations."""
|
||||
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict) -> None:
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
return self.robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
return self.robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.robot.name
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str:
|
||||
return self.robot.robot_type
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Keyboard/Pedal Listeners
|
||||
# ============================================================================
|
||||
|
||||
def init_rac_keyboard_listener():
|
||||
"""Initialize keyboard listener with RaC-specific controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logging.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
if key == keyboard.Key.space or key == keyboard.Key.right:
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position")
|
||||
print(" Press 'c' or START to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("[RaC] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("[RaC] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
start_pedal_listener(events)
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
def start_pedal_listener(events: dict):
|
||||
"""Start foot pedal listener thread if evdev is available."""
|
||||
import threading
|
||||
|
||||
try:
|
||||
from evdev import InputDevice, ecodes # noqa: F401
|
||||
except ImportError:
|
||||
logging.info("[Pedal] evdev not installed - pedal support disabled")
|
||||
return
|
||||
|
||||
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
KEY_LEFT = "KEY_A"
|
||||
KEY_RIGHT = "KEY_C"
|
||||
|
||||
def pedal_reader():
|
||||
try:
|
||||
dev = InputDevice(PEDAL_DEVICE)
|
||||
print(f"[Pedal] Connected: {dev.name}")
|
||||
|
||||
for ev in dev.read_loop():
|
||||
if ev.type != ecodes.EV_KEY:
|
||||
continue
|
||||
|
||||
from evdev import categorize # noqa: F401
|
||||
key = categorize(ev)
|
||||
code = key.keycode
|
||||
if isinstance(code, (list, tuple)):
|
||||
code = code[0]
|
||||
|
||||
if key.keystate != 1:
|
||||
continue
|
||||
|
||||
if events["in_reset"]:
|
||||
if code in [KEY_LEFT, KEY_RIGHT]:
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if code == KEY_RIGHT:
|
||||
if events["correction_active"]:
|
||||
events["exit_early"] = True
|
||||
elif not events["policy_paused"]:
|
||||
events["policy_paused"] = True
|
||||
elif code == KEY_LEFT:
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
events["start_next_episode"] = True
|
||||
|
||||
except FileNotFoundError:
|
||||
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
|
||||
except PermissionError:
|
||||
logging.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}")
|
||||
except Exception as e:
|
||||
logging.debug(f"[Pedal] Error: {e}")
|
||||
|
||||
thread = threading.Thread(target=pedal_reader, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for RaC recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, robot_proc, obs_proc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RTC Inference Thread (from evaluate_with_rtc.py)
|
||||
# ============================================================================
|
||||
|
||||
def rtc_inference_thread(
|
||||
policy,
|
||||
obs_holder: dict,
|
||||
hw_features: dict,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
queue_holder: dict,
|
||||
shutdown_event: Event,
|
||||
policy_active: Event,
|
||||
cfg: RaCRTCConfig,
|
||||
):
|
||||
"""Background thread that generates action chunks using RTC."""
|
||||
try:
|
||||
logger.info("[RTC] ========== INFERENCE THREAD STARTED ==========")
|
||||
logger.info(f"[RTC] policy={policy.name}, hw_features has {len(hw_features)} keys")
|
||||
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.dataset.fps
|
||||
policy_device = policy.config.device
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
inference_count = 0
|
||||
wait_logged = False
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not policy_active.is_set():
|
||||
if not wait_logged:
|
||||
logger.info("[RTC] Waiting for policy_active...")
|
||||
wait_logged = True
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
wait_logged = False
|
||||
|
||||
action_queue = queue_holder["queue"]
|
||||
if action_queue is None:
|
||||
logger.warning("[RTC] queue_holder['queue'] is None!")
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
obs_filtered = obs_holder.get("obs")
|
||||
if obs_filtered is None:
|
||||
logger.warning("[RTC] obs_holder['obs'] is None!")
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
qsize = action_queue.qsize()
|
||||
if qsize <= get_actions_threshold:
|
||||
try:
|
||||
if inference_count == 0:
|
||||
logger.info(f"[RTC] Starting first inference, obs keys={len(obs_filtered)}, qsize={qsize}")
|
||||
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation")
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].float() / 255
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0).to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.dataset.single_task]
|
||||
obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower")
|
||||
|
||||
preprocessed_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
postprocessed_actions = postprocessor(actions).squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference)
|
||||
|
||||
inference_count += 1
|
||||
logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}")
|
||||
except Exception as e:
|
||||
logger.error(f"[RTC] Inference error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
logger.info("[RTC] Inference thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[RTC] THREAD CRASHED: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Rollout Loop
|
||||
# ============================================================================
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rac_rtc_rollout_loop(
|
||||
robot: RobotWrapper,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
cfg: RaCRTCConfig,
|
||||
queue_holder: dict,
|
||||
obs_holder: dict, # Main loop writes obs here for RTC thread to read
|
||||
policy_active: Event,
|
||||
hw_features: dict,
|
||||
) -> dict:
|
||||
"""RaC rollout loop with RTC for smooth policy execution."""
|
||||
fps = cfg.dataset.fps
|
||||
single_task = cfg.dataset.single_task
|
||||
control_time_s = cfg.dataset.episode_time_s
|
||||
device = get_safe_torch_device(cfg.device)
|
||||
|
||||
# Reset policy state
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
streaming = dataset._streaming_encoder is not None
|
||||
frame_buffer = [] if not streaming else None
|
||||
stats = {
|
||||
"total_frames": 0,
|
||||
"autonomous_frames": 0,
|
||||
"paused_frames": 0,
|
||||
"correction_frames": 0,
|
||||
}
|
||||
|
||||
teleop.disable_torque()
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
|
||||
# Action keys for converting tensor to dict
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
|
||||
# Interpolation state
|
||||
prev_action: Tensor | None = None
|
||||
interpolated_actions: list[Tensor] = []
|
||||
interp_idx = 0
|
||||
|
||||
if cfg.interpolation:
|
||||
control_interval = 1.0 / (fps * 2) # 2x rate
|
||||
else:
|
||||
control_interval = 1.0 / fps
|
||||
|
||||
robot_action = {}
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# State transition: entering paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
policy_active.clear() # Stop RTC inference
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
print("[RaC] Moving teleop to robot position...")
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
print("[RaC] Teleop aligned. Press 'c' to take control.")
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
# Reset interpolation
|
||||
prev_action = None
|
||||
interpolated_actions = []
|
||||
interp_idx = 0
|
||||
|
||||
# Wait for takeover
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
print("[RaC] Taking control...")
|
||||
teleop.disable_torque()
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
|
||||
# Get observation (ONLY the main loop reads from robot!)
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
# Share observation with RTC thread (thread reads, main loop writes)
|
||||
obs_holder["obs"] = obs_filtered
|
||||
|
||||
if events["correction_active"]:
|
||||
# Human controlling
|
||||
robot_action = teleop.get_action()
|
||||
for key in robot_action:
|
||||
if "gripper" in key:
|
||||
robot_action[key] = -0.65 * robot_action[key]
|
||||
robot.send_action(robot_action)
|
||||
stats["correction_frames"] += 1
|
||||
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
if streaming:
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
elif waiting_for_takeover:
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
elif events["policy_paused"]:
|
||||
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
teleop.send_feedback(robot_pos)
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
else:
|
||||
# Policy execution with RTC
|
||||
if not policy_active.is_set():
|
||||
policy_active.set()
|
||||
logger.info("[ROLLOUT] Policy activated, waiting for first actions...")
|
||||
|
||||
action_queue = queue_holder["queue"]
|
||||
|
||||
# Get action from queue (with interpolation)
|
||||
if interp_idx >= len(interpolated_actions):
|
||||
new_action = action_queue.get() if action_queue else None
|
||||
|
||||
# Log queue status periodically
|
||||
if stats["autonomous_frames"] == 0 and new_action is None:
|
||||
qsize = action_queue.qsize() if action_queue else -1
|
||||
if timestamp < 0.5 or int(timestamp * 10) % 10 == 0:
|
||||
logger.info(f"[ROLLOUT] Waiting for actions... queue_size={qsize}, obs_set={obs_holder.get('obs') is not None}")
|
||||
|
||||
if new_action is not None:
|
||||
current_action = new_action.cpu()
|
||||
|
||||
if cfg.interpolation and prev_action is not None:
|
||||
mid = prev_action + 0.5 * (current_action - prev_action)
|
||||
interpolated_actions = [mid, current_action]
|
||||
else:
|
||||
interpolated_actions = [current_action]
|
||||
|
||||
prev_action = current_action
|
||||
interp_idx = 0
|
||||
|
||||
if stats["autonomous_frames"] == 0:
|
||||
logger.info(f"[ROLLOUT] Got first action! Starting robot motion.")
|
||||
|
||||
if interp_idx < len(interpolated_actions):
|
||||
action_to_send = interpolated_actions[interp_idx]
|
||||
interp_idx += 1
|
||||
|
||||
robot_action = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action_to_send):
|
||||
robot_action[key] = action_to_send[i].item()
|
||||
|
||||
robot.send_action(robot_action)
|
||||
stats["autonomous_frames"] += 1
|
||||
|
||||
# Record at original fps
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
if streaming:
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
if cfg.display_data:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_time = control_interval - dt
|
||||
if sleep_time > 0:
|
||||
precise_sleep(sleep_time)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
policy_active.clear()
|
||||
teleop.disable_torque()
|
||||
|
||||
if not streaming:
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def reset_loop(robot: RobotWrapper, teleop: Teleoperator, events: dict, fps: int):
|
||||
"""Reset period where human repositions environment."""
|
||||
print("\n" + "=" * 65)
|
||||
print(" [RaC] RESET")
|
||||
print("=" * 65)
|
||||
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
print(" Press any key/pedal to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
events["start_next_episode"] = False
|
||||
teleop.disable_torque()
|
||||
print(" Teleop enabled - press any key/pedal to start episode")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
for key in action:
|
||||
if "gripper" in key:
|
||||
action[key] = -0.65 * action[key]
|
||||
robot.send_action(action)
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Entry Point
|
||||
# ============================================================================
|
||||
|
||||
@parser.wrap()
|
||||
def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
|
||||
"""Main RaC data collection function with RTC."""
|
||||
init_logging()
|
||||
logging.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="rac_rtc_collection_openarms")
|
||||
|
||||
robot_raw = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, robot_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot_raw.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot_raw.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
shutdown_event = Event()
|
||||
policy_active = Event()
|
||||
rtc_thread = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if cfg.dataset.streaming_encoding:
|
||||
dataset.start_streaming_encoder()
|
||||
if hasattr(robot_raw, "cameras") and robot_raw.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot_raw.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
)
|
||||
|
||||
# Load policy
|
||||
logger.info(f"Loading policy from: {cfg.policy.pretrained_path}")
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
|
||||
# Override compile_model for real-time inference (first compile takes minutes)
|
||||
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
if cfg.policy.type in ["pi05", "pi0"]:
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
logger.info(f"Set compile_model={cfg.use_torch_compile} for real-time inference")
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
policy.init_rtc_processor()
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info(f"Policy loaded: {policy.name}")
|
||||
|
||||
# Setup preprocessor/postprocessor
|
||||
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
# Connect robot and wrap for thread safety
|
||||
robot_raw.connect()
|
||||
robot = RobotWrapper(robot_raw)
|
||||
|
||||
teleop.connect()
|
||||
listener, events = init_rac_keyboard_listener()
|
||||
|
||||
# Shared state holders (main loop writes, RTC thread reads)
|
||||
queue_holder = {"queue": ActionQueue(cfg.rtc)}
|
||||
obs_holder = {"obs": None, "robot_type": robot.robot_type} # Main loop updates obs
|
||||
|
||||
# Start RTC inference thread
|
||||
# NOTE: Thread does NOT access robot directly - reads from obs_holder
|
||||
rtc_thread = Thread(
|
||||
target=rtc_inference_thread,
|
||||
args=(
|
||||
policy,
|
||||
obs_holder, # Thread reads obs from here (set by main loop)
|
||||
hw_features,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
queue_holder,
|
||||
shutdown_event,
|
||||
policy_active,
|
||||
cfg,
|
||||
),
|
||||
daemon=True,
|
||||
name="RTCInference",
|
||||
)
|
||||
rtc_thread.start()
|
||||
logger.info("Started RTC inference thread")
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(" RaC Data Collection with RTC")
|
||||
print("=" * 65)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Task: {cfg.dataset.single_task}")
|
||||
print(f" FPS: {cfg.dataset.fps}")
|
||||
print(f" Interpolation: {cfg.interpolation}")
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy")
|
||||
print(" c - Take control")
|
||||
print(" → - End episode")
|
||||
print(" ESC - Stop and push to hub")
|
||||
print("=" * 65 + "\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
# Fresh action queue per episode (update holder so thread sees it)
|
||||
queue_holder["queue"] = ActionQueue(cfg.rtc)
|
||||
|
||||
logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}")
|
||||
|
||||
stats = rac_rtc_rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
cfg=cfg,
|
||||
queue_holder=queue_holder,
|
||||
obs_holder=obs_holder,
|
||||
policy_active=policy_active,
|
||||
hw_features=hw_features,
|
||||
)
|
||||
|
||||
logging.info(f"Episode stats: {stats}")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
t_save_start = time.perf_counter()
|
||||
dataset.save_episode()
|
||||
logging.info(f"[RaC] save_episode total: {time.perf_counter() - t_save_start:.2f}s")
|
||||
recorded += 1
|
||||
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(robot, teleop, events, cfg.dataset.fps)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
shutdown_event.set()
|
||||
policy_active.clear()
|
||||
|
||||
if rtc_thread and rtc_thread.is_alive():
|
||||
rtc_thread.join(timeout=2.0)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot_raw.is_connected:
|
||||
robot_raw.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
register_third_party_plugins()
|
||||
rac_rtc_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
|
||||
Usage:
|
||||
# Basic usage with smolvla policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
@@ -58,16 +58,16 @@ Usage:
|
||||
--device=cuda
|
||||
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/reuben_pi0 \
|
||||
--dataset.repo_id=<USER>/so101_cube_in_cup \
|
||||
--policy.path=lipsop/reuben_pi0 \
|
||||
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--use_torch_compile=true \
|
||||
@@ -75,8 +75,8 @@ Usage:
|
||||
|
||||
# With torch.compile on CUDA (CUDA graphs disabled by default)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
@@ -84,8 +84,8 @@ Usage:
|
||||
|
||||
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune \
|
||||
|
||||
@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
@@ -41,7 +41,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
@@ -53,7 +53,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/pi05_check_rtc \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
@@ -11,7 +12,6 @@ from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
@@ -40,9 +40,8 @@ def run_learner(
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||
algorithm.make_optimizers()
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
@@ -84,26 +83,24 @@ def run_learner(
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
def batch_iter(b=batch):
|
||||
while True:
|
||||
yield b
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
|
||||
stats = algorithm.update(batch_iter())
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
log_dict = stats.to_log_dict()
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, "
|
||||
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
weights = algorithm.get_weights()
|
||||
parameters_queue.put_nowait(weights)
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
@@ -147,15 +144,15 @@ def run_actor(
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_weights = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_weights)
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy (returns full action: continuous + discrete)
|
||||
# Get action from policy
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from huggingface_hub import HfApi, list_datasets
|
||||
|
||||
api = HfApi()
|
||||
datasets = list_datasets(author="lerobot-data-collection")
|
||||
print('"[', end="")
|
||||
i=0
|
||||
for dataset in datasets:
|
||||
if "three-folds-dataset" in dataset.id:
|
||||
print("'" + dataset.id + "',", end="")
|
||||
print(']"',)
|
||||
+3
-3
@@ -76,9 +76,9 @@ dependencies = [
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
|
||||
@@ -150,7 +150,7 @@ class Camera(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -530,7 +530,7 @@ class OpenCVCamera(Camera):
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -201,7 +201,7 @@ class Reachy2Camera(Camera):
|
||||
return self.read()
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -573,7 +573,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -211,15 +211,3 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
# Algorithm name registered in RLAlgorithmConfig registry
|
||||
algorithm: str = "sac"
|
||||
|
||||
# Data mixer strategy name. Currently supports "online_offline"
|
||||
mixer: str = "online_offline"
|
||||
# Fraction sampled from online replay when using OnlineOfflineMixer
|
||||
online_ratio: float = 0.5
|
||||
|
||||
# RL trainer iterator
|
||||
async_prefetch: bool = True
|
||||
queue_size: int = 2
|
||||
|
||||
@@ -50,3 +50,8 @@ class RTCAttentionSchedule(str, Enum):
|
||||
ONES = "ONES"
|
||||
LINEAR = "LINEAR"
|
||||
EXP = "EXP"
|
||||
|
||||
|
||||
class RTCTrainingDelayDistribution(str, Enum):
|
||||
UNIFORM = "UNIFORM"
|
||||
EXP = "EXP"
|
||||
|
||||
@@ -13,6 +13,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
@@ -227,19 +231,20 @@ def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_si
|
||||
return img[:, ::downsample_factor, ::downsample_factor]
|
||||
|
||||
|
||||
def _load_single_image(path: str) -> np.ndarray:
|
||||
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
|
||||
return auto_downsample_height_width(img)
|
||||
|
||||
|
||||
def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
sampled_indices = sample_indices(len(image_paths))
|
||||
paths = [image_paths[idx] for idx in sampled_indices]
|
||||
|
||||
images = None
|
||||
for i, idx in enumerate(sampled_indices):
|
||||
path = image_paths[idx]
|
||||
# we load as uint8 to reduce memory usage
|
||||
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
|
||||
img = auto_downsample_height_width(img)
|
||||
|
||||
if images is None:
|
||||
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
||||
with ThreadPoolExecutor(max_workers=min(8, len(paths))) as pool:
|
||||
loaded = list(pool.map(_load_single_image, paths))
|
||||
|
||||
images = np.empty((len(loaded), *loaded[0].shape), dtype=np.uint8)
|
||||
for i, img in enumerate(loaded):
|
||||
images[i] = img
|
||||
|
||||
return images
|
||||
@@ -504,27 +509,46 @@ def compute_episode_stats(
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue
|
||||
|
||||
def _compute_single_feature_stats(key, data):
|
||||
t0 = time.perf_counter()
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
kd = True
|
||||
else:
|
||||
ep_ft_array = data
|
||||
axes_to_reduce = 0
|
||||
keepdims = data.ndim == 1
|
||||
kd = data.ndim == 1
|
||||
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
|
||||
)
|
||||
stats = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=kd, quantile_list=quantile_list)
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
stats = {k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in stats.items()}
|
||||
|
||||
dt = time.perf_counter() - t0
|
||||
if dt > 0.1:
|
||||
logging.info(f"[compute_episode_stats] {key} ({features[key]['dtype']}): {dt:.2f}s")
|
||||
return key, stats
|
||||
|
||||
# Split into image/video features (heavy I/O) and numeric features (fast)
|
||||
image_keys = [(k, d) for k, d in episode_data.items()
|
||||
if k in features and features[k]["dtype"] in ["image", "video"]]
|
||||
numeric_keys = [(k, d) for k, d in episode_data.items()
|
||||
if k in features and features[k]["dtype"] not in ["image", "video", "string"]]
|
||||
|
||||
# Run image features in parallel (I/O bound)
|
||||
if image_keys:
|
||||
with ThreadPoolExecutor(max_workers=len(image_keys)) as pool:
|
||||
futures = [pool.submit(_compute_single_feature_stats, k, d) for k, d in image_keys]
|
||||
for f in futures:
|
||||
key, stats = f.result()
|
||||
ep_stats[key] = stats
|
||||
|
||||
# Numeric features are fast — run sequentially
|
||||
for k, d in numeric_keys:
|
||||
_, stats = _compute_single_feature_stats(k, d)
|
||||
ep_stats[k] = stats
|
||||
|
||||
return ep_stats
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats, get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
@@ -1522,6 +1522,122 @@ def modify_tasks(
|
||||
return dataset
|
||||
|
||||
|
||||
def recompute_stats(
|
||||
dataset: LeRobotDataset,
|
||||
skip_image_video: bool = True,
|
||||
delta_action: bool = False,
|
||||
delta_exclude_joints: list[str] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Recompute stats.json from scratch by iterating all episodes.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobotDataset to recompute stats for.
|
||||
skip_image_video: If True (default), only recompute stats for numeric features
|
||||
(action, state, etc.) and keep existing image/video stats unchanged.
|
||||
delta_action: If True, compute action stats as delta (action - state).
|
||||
Useful when training with use_delta_actions=True so normalization matches.
|
||||
delta_exclude_joints: Joint names to exclude from delta conversion when
|
||||
delta_action=True. These dims keep absolute stats. Uses dataset's
|
||||
action feature names to build the mask. Default: ["gripper"].
|
||||
|
||||
Returns:
|
||||
The same dataset with updated stats.
|
||||
"""
|
||||
features = dataset.meta.features
|
||||
numeric_features = {
|
||||
k: v for k, v in features.items()
|
||||
if v["dtype"] not in ["image", "video", "string"]
|
||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||
}
|
||||
|
||||
if skip_image_video:
|
||||
features_to_compute = numeric_features
|
||||
else:
|
||||
features_to_compute = {
|
||||
k: v for k, v in features.items()
|
||||
if v["dtype"] != "string"
|
||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||
}
|
||||
|
||||
# Build delta mask if delta_action is enabled
|
||||
delta_mask = None
|
||||
if delta_action and "action" in features and "observation.state" in features:
|
||||
if delta_exclude_joints is None:
|
||||
delta_exclude_joints = ["gripper"]
|
||||
action_names = features["action"].get("names")
|
||||
if action_names is not None:
|
||||
exclude = set(delta_exclude_joints)
|
||||
delta_mask = [n not in exclude for n in action_names]
|
||||
else:
|
||||
action_dim = features["action"]["shape"][0]
|
||||
delta_mask = [True] * action_dim
|
||||
# Only recompute action stats when delta is enabled — state stays unchanged
|
||||
features_to_compute = {"action": features["action"]}
|
||||
logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
|
||||
else:
|
||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
# Also need state for delta computation even though we don't recompute state stats
|
||||
needs_state = delta_mask is not None
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
for ep_idx in sorted(df["episode_index"].unique()):
|
||||
ep_df = df[df["episode_index"] == ep_idx]
|
||||
episode_data = {}
|
||||
for key in numeric_keys:
|
||||
if key in ep_df.columns:
|
||||
values = ep_df[key].values
|
||||
if hasattr(values[0], "__len__"):
|
||||
episode_data[key] = np.stack(values)
|
||||
else:
|
||||
episode_data[key] = np.array(values)
|
||||
|
||||
# Apply delta conversion to actions before computing stats
|
||||
if delta_mask is not None and "action" in episode_data:
|
||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||
|
||||
# Load state for delta even if we're not computing state stats
|
||||
if needs_state and "observation.state" in ep_df.columns:
|
||||
state_values = ep_df["observation.state"].values
|
||||
if hasattr(state_values[0], "__len__"):
|
||||
states = np.stack(state_values)
|
||||
else:
|
||||
states = np.array(state_values)
|
||||
actions_t = torch.from_numpy(episode_data["action"]).float()
|
||||
states_t = torch.from_numpy(states).float()
|
||||
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
|
||||
|
||||
ep_stats = compute_episode_stats(episode_data, features_to_compute)
|
||||
all_episode_stats.append(ep_stats)
|
||||
|
||||
if not all_episode_stats:
|
||||
logging.warning("No episode stats computed")
|
||||
return dataset
|
||||
|
||||
new_stats = aggregate_stats(all_episode_stats)
|
||||
|
||||
# Merge: keep existing stats for features we didn't recompute
|
||||
if dataset.meta.stats:
|
||||
for key, value in dataset.meta.stats.items():
|
||||
if key not in new_stats:
|
||||
new_stats[key] = value
|
||||
|
||||
write_stats(new_stats, dataset.root)
|
||||
dataset.meta.stats = new_stats
|
||||
|
||||
logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
|
||||
return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
|
||||
@@ -18,22 +18,30 @@ import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import os
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import torch
|
||||
import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.compute_stats import (
|
||||
RunningQuantileStats,
|
||||
aggregate_stats,
|
||||
auto_downsample_height_width,
|
||||
compute_episode_stats,
|
||||
)
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
@@ -68,6 +76,7 @@ from lerobot.datasets.utils import (
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
decode_video_frames,
|
||||
@@ -79,7 +88,6 @@ from lerobot.datasets.video_utils import (
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -420,8 +428,10 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
logging.info(f"[meta.save_episode] aggregate+write_stats: {time.perf_counter() - t0:.2f}s")
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
@@ -544,13 +554,11 @@ class LeRobotDatasetMetadata:
|
||||
return obj
|
||||
|
||||
|
||||
def _encode_video_worker(
|
||||
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
|
||||
) -> Path:
|
||||
def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
|
||||
encode_video_frames(img_dir, temp_path, fps, overwrite=True)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
@@ -569,7 +577,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -656,7 +663,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
@@ -682,13 +689,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
|
||||
encoding is CPU-heavy.
|
||||
"""
|
||||
super().__init__()
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.image_transforms = image_transforms
|
||||
@@ -700,7 +702,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.delta_indices = None
|
||||
self.batch_encoding_size = batch_encoding_size
|
||||
self.episodes_since_last_encoding = 0
|
||||
self.vcodec = vcodec
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -708,6 +709,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
|
||||
self._streaming_encoder = None
|
||||
self._running_video_stats = {}
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -938,30 +941,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(
|
||||
self, abs_idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
|
||||
"""Compute query indices for delta timestamps.
|
||||
|
||||
Args:
|
||||
abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes).
|
||||
ep_idx: The episode index.
|
||||
|
||||
Returns:
|
||||
A tuple of (query_indices, padding) where:
|
||||
- query_indices: Dict mapping keys to lists of absolute indices to query
|
||||
- padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions
|
||||
"""
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
query_indices = {
|
||||
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
|
||||
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -1053,12 +1043,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._ensure_hf_dataset_loaded()
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
# Use the absolute index from the dataset for delta timestamp calculations
|
||||
abs_idx = item["index"].item()
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
@@ -1102,6 +1090,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files.
|
||||
The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo))
|
||||
"""
|
||||
if self._streaming_encoder:
|
||||
self._streaming_encoder.close()
|
||||
self._close_writer()
|
||||
self.meta._close_writer()
|
||||
|
||||
@@ -1153,6 +1143,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Automatically add frame_index and timestamp to episode buffer
|
||||
frame_index = self.episode_buffer["size"]
|
||||
if frame_index == 0 and self._streaming_encoder:
|
||||
self._streaming_encoder.start_episode(self.meta.video_keys, self.root)
|
||||
self._init_running_video_stats()
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
@@ -1166,14 +1159,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
if self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
|
||||
self._save_image(frame[key], img_path, compress_level)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
if self._streaming_encoder and self.features[key]["dtype"] == "video":
|
||||
self._feed_streaming_frame(key, frame[key])
|
||||
self.episode_buffer[key].append(None)
|
||||
else:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
|
||||
self._save_image(frame[key], img_path, compress_level)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
@@ -1224,53 +1221,50 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
if self._streaming_encoder:
|
||||
filtered = {k: v for k, v in episode_buffer.items() if k not in self.meta.video_keys}
|
||||
ep_stats = compute_episode_stats(filtered, self.features)
|
||||
for key in self.meta.video_keys:
|
||||
stats = self._running_video_stats[key].get_statistics()
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else (v.reshape(-1, 1, 1) / 255.0)
|
||||
for k, v in stats.items()
|
||||
}
|
||||
else:
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
t_stats = time.perf_counter() - t0
|
||||
|
||||
t0 = time.perf_counter()
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
t_save_data = time.perf_counter() - t0
|
||||
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
# num_cameras * num_threads = (total_cpu -1)
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
|
||||
future_to_key = {
|
||||
executor.submit(
|
||||
_encode_video_worker,
|
||||
video_key,
|
||||
episode_index,
|
||||
self.root,
|
||||
self.fps,
|
||||
self.vcodec,
|
||||
): video_key
|
||||
for video_key in self.meta.video_keys
|
||||
}
|
||||
|
||||
results = {}
|
||||
for future in concurrent.futures.as_completed(future_to_key):
|
||||
video_key = future_to_key[future]
|
||||
try:
|
||||
temp_path = future.result()
|
||||
results[video_key] = temp_path
|
||||
except Exception as exc:
|
||||
logging.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
raise exc
|
||||
|
||||
for video_key in self.meta.video_keys:
|
||||
temp_path = results[video_key]
|
||||
ep_metadata.update(
|
||||
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
|
||||
)
|
||||
else:
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
t0 = time.perf_counter()
|
||||
if has_video_keys and self._streaming_encoder:
|
||||
video_paths = self._streaming_encoder.finish_episode()
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_paths[video_key]))
|
||||
elif has_video_keys and not use_batched_encoding:
|
||||
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
|
||||
for video_key, video_path in zip(self.meta.video_keys, video_paths):
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
|
||||
t_video = time.perf_counter() - t0
|
||||
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
t0 = time.perf_counter()
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
t_meta = time.perf_counter() - t0
|
||||
|
||||
logging.info(
|
||||
f"[save_episode] ep={episode_index} frames={episode_length} | "
|
||||
f"stats={t_stats:.2f}s data={t_save_data:.2f}s video={t_video:.2f}s meta={t_meta:.2f}s "
|
||||
f"total={t_stats + t_save_data + t_video + t_meta:.2f}s"
|
||||
)
|
||||
|
||||
if has_video_keys and use_batched_encoding:
|
||||
# Check if we should trigger batch encoding
|
||||
@@ -1438,6 +1432,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index: int,
|
||||
temp_path: Path | None = None,
|
||||
) -> dict:
|
||||
t0 = time.perf_counter()
|
||||
# Encode episode frames into a temporary video
|
||||
if temp_path is None:
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
@@ -1511,9 +1506,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f"videos/{video_key}/from_timestamp": latest_duration_in_s,
|
||||
f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
|
||||
}
|
||||
save_time = time.perf_counter() - t0
|
||||
rate = ep_duration_in_s / save_time if save_time > 0 else float("inf")
|
||||
logging.info(
|
||||
f"[save_episode_video] {video_key} ep={episode_index} "
|
||||
f"save={save_time:.2f}s video_dur={ep_duration_in_s:.1f}s "
|
||||
f"size={ep_size_in_mb:.1f}MB rate={rate:.2f}x realtime"
|
||||
)
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
||||
if self._streaming_encoder:
|
||||
self._streaming_encoder.stop_episode()
|
||||
# Clean up image files for the current episode buffer
|
||||
if delete_images:
|
||||
# Wait for the async image writer to finish
|
||||
@@ -1522,7 +1526,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for cam_key in self.meta.image_keys:
|
||||
for cam_key in self.meta.camera_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
@@ -1555,13 +1559,66 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def start_streaming_encoder(self):
|
||||
"""Enable streaming video encoding for recording."""
|
||||
if len(self.meta.video_keys) > 0:
|
||||
self._streaming_encoder = StreamingVideoEncoder(fps=self.fps)
|
||||
self._running_video_stats = {}
|
||||
|
||||
def _init_running_video_stats(self):
|
||||
self._running_video_stats = {key: RunningQuantileStats() for key in self.meta.video_keys}
|
||||
|
||||
def _feed_streaming_frame(self, key: str, image) -> None:
|
||||
"""Feed image to streaming encoder and accumulate running stats."""
|
||||
if isinstance(image, np.ndarray):
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3, 4):
|
||||
img_chw = image
|
||||
else:
|
||||
img_chw = image.transpose(2, 0, 1)
|
||||
else:
|
||||
img_chw = np.array(image).transpose(2, 0, 1)
|
||||
|
||||
self._streaming_encoder.feed_frame(key, image)
|
||||
img_ds = auto_downsample_height_width(img_chw)
|
||||
c, h, w = img_ds.shape
|
||||
self._running_video_stats[key].update(
|
||||
img_ds.transpose(1, 2, 0).reshape(-1, c).astype(np.float64)
|
||||
)
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
|
||||
|
||||
def _encode_multiple_temporary_episode_videos(self, video_keys, episode_index):
|
||||
temp_paths = []
|
||||
img_dirs = []
|
||||
for video_key in video_keys:
|
||||
temp_paths.append(Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4")
|
||||
img_dirs.append(self._get_image_file_dir(episode_index, video_key))
|
||||
fps = [self.fps]*len(video_keys)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
with ProcessPoolExecutor(max_workers=len(video_keys)) as executor:
|
||||
executor.map(encode_video_frames,img_dirs,temp_paths,fps)
|
||||
encode_time = time.perf_counter() - t0
|
||||
|
||||
n_frames = len(list(img_dirs[0].glob("*"))) if img_dirs and img_dirs[0].exists() else 0
|
||||
video_duration_s = n_frames / self.fps if n_frames > 0 else 0
|
||||
rate = video_duration_s / encode_time if encode_time > 0 else float("inf")
|
||||
logging.info(
|
||||
f"[encode_videos] ep={episode_index} keys={len(video_keys)} "
|
||||
f"encode={encode_time:.2f}s video_dur={video_duration_s:.1f}s "
|
||||
f"rate={rate:.2f}x realtime"
|
||||
)
|
||||
|
||||
for img_dir in img_dirs:
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
return temp_paths
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -1577,11 +1634,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
streaming_encoding: bool = False,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -1598,7 +1653,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_writer = None
|
||||
obj.batch_encoding_size = batch_encoding_size
|
||||
obj.episodes_since_last_encoding = 0
|
||||
obj.vcodec = vcodec
|
||||
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
@@ -1616,6 +1670,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj._current_file_start_frame = None
|
||||
obj._streaming_encoder = None
|
||||
obj._running_video_stats = {}
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
obj._streaming_encoder = StreamingVideoEncoder(fps=fps)
|
||||
# Initialize tracking for incremental recording
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
|
||||
@@ -122,9 +122,19 @@ def load_nested_dataset(
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
with SuppressProgressBars():
|
||||
# We use .from_parquet() memory-mapped loading for efficiency
|
||||
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
|
||||
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
|
||||
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
|
||||
# PyArrow loads the entire dataset into memory
|
||||
if episodes is None:
|
||||
return Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
|
||||
arrow_dataset = pa_ds.dataset(paths, format="parquet")
|
||||
filter_expr = pa_ds.field("episode_index").isin(episodes)
|
||||
table = arrow_dataset.to_table(filter=filter_expr)
|
||||
|
||||
if features is not None:
|
||||
table = table.cast(features.arrow_schema)
|
||||
|
||||
return Dataset(table)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
|
||||
@@ -529,7 +529,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
|
||||
@@ -16,16 +16,18 @@
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import av
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
@@ -310,7 +312,7 @@ def encode_video_frames(
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
overwrite: bool = False,
|
||||
overwrite: bool = True,
|
||||
preset: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
@@ -355,6 +357,9 @@ def encode_video_frames(
|
||||
if crf is not None:
|
||||
video_options["crf"] = str(crf)
|
||||
|
||||
#TEMPORARY FIX
|
||||
video_options["preset"] = "12"
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
@@ -397,6 +402,141 @@ def encode_video_frames(
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
|
||||
|
||||
_DONE = object()
|
||||
|
||||
|
||||
class _CameraEncoder:
|
||||
"""Encodes frames for one camera in a daemon thread."""
|
||||
|
||||
def __init__(self, video_path, fps, vcodec, pix_fmt, g, crf):
|
||||
self.video_path = Path(video_path)
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.queue = queue.Queue()
|
||||
self._thread = None
|
||||
self._cancelled = False
|
||||
|
||||
def start(self):
|
||||
self.video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._thread = Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def finish(self) -> Path:
|
||||
self.queue.put(_DONE)
|
||||
self._thread.join(timeout=120)
|
||||
return self.video_path
|
||||
|
||||
def cancel(self):
|
||||
self._cancelled = True
|
||||
while not self.queue.empty():
|
||||
try:
|
||||
self.queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
self.queue.put(_DONE)
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
if self.video_path.parent.exists():
|
||||
shutil.rmtree(self.video_path.parent, ignore_errors=True)
|
||||
|
||||
def _run(self):
|
||||
options = {}
|
||||
if self.g is not None:
|
||||
options["g"] = str(self.g)
|
||||
if self.crf is not None:
|
||||
options["crf"] = str(self.crf)
|
||||
if self.vcodec == "libsvtav1":
|
||||
options["preset"] = "12"
|
||||
|
||||
output = None
|
||||
output_stream = None
|
||||
try:
|
||||
while True:
|
||||
data = self.queue.get()
|
||||
if data is _DONE or self._cancelled:
|
||||
break
|
||||
|
||||
if isinstance(data, np.ndarray):
|
||||
if data.ndim == 3 and data.shape[0] in (1, 3, 4):
|
||||
data = data.transpose(1, 2, 0)
|
||||
pil = Image.fromarray(data.astype(np.uint8)).convert("RGB")
|
||||
else:
|
||||
pil = data.convert("RGB")
|
||||
|
||||
if output is None:
|
||||
w, h = pil.size
|
||||
output = av.open(str(self.video_path), "w")
|
||||
output_stream = output.add_stream(self.vcodec, self.fps, options=options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = w
|
||||
output_stream.height = h
|
||||
|
||||
pkt = output_stream.encode(av.VideoFrame.from_image(pil))
|
||||
if pkt:
|
||||
output.mux(pkt)
|
||||
|
||||
if output_stream and not self._cancelled:
|
||||
pkt = output_stream.encode()
|
||||
if pkt:
|
||||
output.mux(pkt)
|
||||
except Exception as e:
|
||||
logging.error(f"[StreamingEncoder] {e}")
|
||||
finally:
|
||||
if output:
|
||||
output.close()
|
||||
|
||||
|
||||
class StreamingVideoEncoder:
|
||||
"""Encodes video on-the-fly using one background thread per camera.
|
||||
|
||||
PyAV releases the GIL during encoding, so Python threads give true
|
||||
parallelism for the CPU-intensive codec work. The queue is unbounded
|
||||
so feed_frame never blocks the caller (teleop thread always has priority).
|
||||
"""
|
||||
|
||||
def __init__(self, fps, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30):
|
||||
self.fps = fps
|
||||
self._vcodec = vcodec
|
||||
self._pix_fmt = pix_fmt
|
||||
self._g = g
|
||||
self._crf = crf
|
||||
self._encoders: dict[str, _CameraEncoder] = {}
|
||||
|
||||
def start_episode(self, video_keys, temp_dir):
|
||||
self.stop_episode()
|
||||
for key in video_keys:
|
||||
path = Path(tempfile.mkdtemp(dir=temp_dir)) / f"{key}_stream.mp4"
|
||||
enc = _CameraEncoder(path, self.fps, self._vcodec, self._pix_fmt, self._g, self._crf)
|
||||
enc.start()
|
||||
self._encoders[key] = enc
|
||||
|
||||
def feed_frame(self, video_key, image):
|
||||
"""Non-blocking: put frame on unbounded queue (never blocks caller)."""
|
||||
enc = self._encoders.get(video_key)
|
||||
if enc:
|
||||
enc.queue.put(image)
|
||||
|
||||
def finish_episode(self) -> dict[str, Path]:
|
||||
"""Flush all encoders, wait for completion, return {key: video_path}."""
|
||||
paths = {}
|
||||
for key, enc in self._encoders.items():
|
||||
paths[key] = enc.finish()
|
||||
self._encoders.clear()
|
||||
return paths
|
||||
|
||||
def stop_episode(self):
|
||||
"""Cancel current episode encoding (for re-record)."""
|
||||
for enc in self._encoders.values():
|
||||
enc.cancel()
|
||||
self._encoders.clear()
|
||||
|
||||
def close(self):
|
||||
self.stop_episode()
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
):
|
||||
|
||||
@@ -18,4 +18,7 @@ from .motors_bus import (
|
||||
Motor,
|
||||
MotorCalibration,
|
||||
MotorNormMode,
|
||||
MotorsBus, # Backward compatibility (alias for SerialMotorsBus)
|
||||
MotorsBusBase,
|
||||
SerialMotorsBus,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""Configuration tables for Damiao motors."""
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
# Motor type definitions
|
||||
class MotorType(IntEnum):
|
||||
@@ -33,7 +33,6 @@ class MotorType(IntEnum):
|
||||
DMH6215 = 11
|
||||
DMG6220 = 12
|
||||
|
||||
|
||||
# Control modes
|
||||
class ControlMode(IntEnum):
|
||||
MIT = 1
|
||||
@@ -41,7 +40,6 @@ class ControlMode(IntEnum):
|
||||
VEL = 3
|
||||
TORQUE_POS = 4
|
||||
|
||||
|
||||
# Motor variable IDs (RID)
|
||||
class MotorVariable(IntEnum):
|
||||
UV_VALUE = 0
|
||||
@@ -90,8 +88,7 @@ class MotorVariable(IntEnum):
|
||||
P_M = 80
|
||||
XOUT = 81
|
||||
|
||||
|
||||
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||
# PMAX: Maximum position (rad)
|
||||
# VMAX: Maximum velocity (rad/s)
|
||||
# TMAX: Maximum torque (N·m)
|
||||
@@ -147,10 +144,10 @@ MODEL_RESOLUTION = {
|
||||
|
||||
# CAN baudrates supported by Damiao motors
|
||||
AVAILABLE_BAUDRATES = [
|
||||
125000, # 0: 125 kbps
|
||||
200000, # 1: 200 kbps
|
||||
250000, # 2: 250 kbps
|
||||
500000, # 3: 500 kbps
|
||||
125000, # 0: 125 kbps
|
||||
200000, # 1: 200 kbps
|
||||
250000, # 2: 250 kbps
|
||||
500000, # 3: 500 kbps
|
||||
1000000, # 4: 1 mbps (default for OpenArms)
|
||||
2000000, # 5: 2 mbps
|
||||
2500000, # 6: 2.5 mbps
|
||||
@@ -163,6 +160,9 @@ DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
|
||||
# Default timeout in milliseconds
|
||||
DEFAULT_TIMEOUT_MS = 1000
|
||||
|
||||
# Data that should be normalized
|
||||
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
|
||||
|
||||
# OpenArms specific configurations
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms has 7 DOF per arm (14 total for dual arm)
|
||||
@@ -182,14 +182,14 @@ OPENARMS_GRIPPER_MOTOR_IDS = {
|
||||
|
||||
# Default motor types for OpenArms
|
||||
OPENARMS_DEFAULT_MOTOR_TYPES = {
|
||||
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
|
||||
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
|
||||
"joint_3": MotorType.DM4340, # Shoulder rotation
|
||||
"joint_4": MotorType.DM4340, # Elbow flex
|
||||
"joint_5": MotorType.DM4310, # Wrist roll
|
||||
"joint_6": MotorType.DM4310, # Wrist pitch
|
||||
"joint_7": MotorType.DM4310, # Wrist rotation
|
||||
"gripper": MotorType.DM4310, # Gripper
|
||||
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
|
||||
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
|
||||
"joint_3": MotorType.DM4340, # Shoulder rotation
|
||||
"joint_4": MotorType.DM4340, # Elbow flex
|
||||
"joint_5": MotorType.DM4310, # Wrist roll
|
||||
"joint_6": MotorType.DM4310, # Wrist pitch
|
||||
"joint_7": MotorType.DM4310, # Wrist rotation
|
||||
"gripper": MotorType.DM4310, # Gripper
|
||||
}
|
||||
|
||||
# MIT control parameter ranges
|
||||
|
||||
@@ -22,7 +22,8 @@ import logging
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
|
||||
from ..encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
@@ -202,9 +203,9 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
|
||||
@@ -17,7 +17,8 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pprint import pformat
|
||||
|
||||
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
FIRMWARE_MAJOR_VERSION,
|
||||
@@ -164,7 +165,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
|
||||
def _handshake(self) -> None:
|
||||
self._assert_motors_exist()
|
||||
self._assert_same_firmware()
|
||||
#self._assert_same_firmware()
|
||||
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||
if self.protocol_version == 0:
|
||||
@@ -297,11 +298,11 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
|
||||
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
self._write(addr, length, motor, 0, num_retry=num_retry)
|
||||
self._write(addr, length, motor_id, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
|
||||
@@ -470,6 +470,13 @@ def make_policy(
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
|
||||
# Store action feature names for delta_exclude_joints support
|
||||
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
|
||||
@@ -50,6 +50,13 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -79,8 +85,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -88,8 +94,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
@@ -605,6 +617,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -714,7 +729,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
if time_emb.dim() == 2:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
||||
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
def mlp_func(action_time_emb):
|
||||
@@ -750,7 +768,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
@@ -846,24 +869,37 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -874,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
@@ -1277,7 +1320,19 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
losses = self.model.forward(
|
||||
images, img_masks, lang_tokens, lang_masks, state, actions, noise=noise, time=time
|
||||
)
|
||||
else:
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1289,12 +1344,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
@@ -21,8 +21,10 @@ import torch
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -126,7 +128,13 @@ def make_pi0_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
delta_step = DeltaActionsProcessorStep(
|
||||
enabled=config.use_delta_actions,
|
||||
exclude_joints=getattr(config, "delta_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -138,6 +146,7 @@ def make_pi0_pre_post_processors(
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
delta_step,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
@@ -149,6 +158,7 @@ def make_pi0_pre_post_processors(
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
@@ -50,8 +50,16 @@ class PI05Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
|
||||
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -78,8 +84,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -87,8 +93,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
@@ -602,6 +614,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -729,7 +744,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
@@ -820,23 +840,35 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -847,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
@@ -1250,7 +1288,17 @@ class PI05Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
|
||||
else:
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1262,12 +1310,12 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
@@ -25,7 +25,9 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -129,10 +131,19 @@ def make_pi05_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
delta_step = DeltaActionsProcessorStep(
|
||||
enabled=config.use_delta_actions,
|
||||
exclude_joints=getattr(config, "delta_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
delta_step,
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
@@ -154,6 +165,7 @@ def make_pi05_pre_post_processors(
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
@@ -41,6 +41,9 @@ class PI0FastConfig(PreTrainedConfig):
|
||||
max_action_dim: int = 32
|
||||
max_action_tokens: int = 256
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -48,12 +48,14 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.processor.delta_action_processor import to_absolute_actions
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -1315,6 +1317,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
action_tokens, action_horizon=action_horizon, action_dim=action_dim
|
||||
)
|
||||
|
||||
if self.config.use_delta_actions and OBS_STATE in batch:
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
continuous_actions = to_absolute_actions(
|
||||
continuous_actions, state, [True] * continuous_actions.shape[-1]
|
||||
)
|
||||
|
||||
return continuous_actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
|
||||
@@ -27,6 +27,7 @@ from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -147,6 +148,7 @@ def make_pi0_fast_pre_post_processors(
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT (RL Token) policy configuration.
|
||||
|
||||
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
|
||||
(Xu et al., Physical Intelligence, 2026)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.policies.sac.configuration_sac import ActorLearnerConfig, ConcurrencyConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTokenConfig:
|
||||
"""Configuration for the RL-token encoder/decoder transformer."""
|
||||
|
||||
input_dim: int = 2048
|
||||
rl_token_dim: int = 2048
|
||||
num_encoder_layers: int = 2
|
||||
num_decoder_layers: int = 2
|
||||
num_heads: int = 8
|
||||
ff_dim: int = 2048
|
||||
dropout: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTActorConfig:
|
||||
"""Configuration for the lightweight RL actor MLP."""
|
||||
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
std: float = 0.1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTCriticConfig:
|
||||
"""Configuration for the RLT critic MLP."""
|
||||
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("rlt")
|
||||
@dataclass
|
||||
class RLTConfig(PreTrainedConfig):
|
||||
"""Configuration for the RLT (RL Token) policy.
|
||||
|
||||
RLT adds an RL-token encoder/decoder to a frozen VLA backbone, then trains
|
||||
a lightweight actor-critic head using the RL token as state representation.
|
||||
The frozen VLA also provides reference action chunks that the actor refines.
|
||||
"""
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
OBS_IMAGE: {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
OBS_STATE: {"min": [0.0], "max": [1.0]},
|
||||
ACTION: {"min": [0.0], "max": [1.0]},
|
||||
}
|
||||
)
|
||||
|
||||
# ── Device ──
|
||||
device: str = "cuda"
|
||||
storage_device: str = "cpu"
|
||||
|
||||
# ── VLA backbone ──
|
||||
vla_checkpoint: str | None = None
|
||||
|
||||
# ── RL-token ──
|
||||
rl_token: RLTokenConfig = field(default_factory=RLTokenConfig)
|
||||
|
||||
# ── Actor / Critic heads ──
|
||||
actor: RLTActorConfig = field(default_factory=RLTActorConfig)
|
||||
critic: RLTCriticConfig = field(default_factory=RLTCriticConfig)
|
||||
|
||||
# ── Action chunks ──
|
||||
chunk_size: int = 10
|
||||
vla_chunk_size: int = 50
|
||||
|
||||
# ── Training parameters ──
|
||||
online_steps: int = 50000
|
||||
offline_steps: int = 5000
|
||||
online_buffer_capacity: int = 100000
|
||||
offline_buffer_capacity: int = 100000
|
||||
online_step_before_learning: int = 500
|
||||
warmup_steps: int = 500
|
||||
async_prefetch: bool = False
|
||||
|
||||
# ── Algorithm hyperparameters ──
|
||||
utd_ratio: int = 5
|
||||
policy_update_freq: int = 2
|
||||
discount: float = 0.99
|
||||
critic_lr: float = 3e-4
|
||||
actor_lr: float = 3e-4
|
||||
rl_token_lr: float = 1e-4
|
||||
tau: float = 0.005
|
||||
clip_grad_norm: float = 10.0
|
||||
num_critics: int = 2
|
||||
bc_reg_coeff: float = 0.1
|
||||
ref_dropout: float = 0.5
|
||||
chunk_stride: int = 2
|
||||
vla_finetune_weight: float = 0.0
|
||||
|
||||
# ── Distributed ──
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
def get_optimizer_preset(self):
|
||||
return None
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if ACTION not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,318 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT (RL Token) policy networks.
|
||||
|
||||
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
|
||||
(Xu et al., Physical Intelligence, 2026)
|
||||
|
||||
Architecture:
|
||||
- RLTokenEncoder: compresses VLA token embeddings into a single compact RL token
|
||||
- RLTokenDecoder: reconstructs VLA embeddings from the RL token (Stage 1 training only)
|
||||
- RLTActor: refines VLA reference action chunks conditioned on (z_rl, proprioception, ref_action)
|
||||
- RLTCritic: Q(x, action_chunk) where x = (z_rl, proprioception)
|
||||
- RLTPolicy: bundles RL-token modules + actor into a PreTrainedPolicy for inference
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rlt.configuration_rlt import RLTConfig
|
||||
|
||||
# ── Building blocks ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""Simple feedforward network with ReLU activations."""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_dims: list[int], output_dim: int):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = []
|
||||
prev = input_dim
|
||||
for h in hidden_dims:
|
||||
layers.append(nn.Linear(prev, h))
|
||||
layers.append(nn.ReLU())
|
||||
prev = h
|
||||
layers.append(nn.Linear(prev, output_dim))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# ── RL Token Encoder ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTokenEncoder(nn.Module):
|
||||
"""Compress VLA token embeddings into a single RL token via a small transformer.
|
||||
|
||||
Appends a learnable ``e_rl`` embedding to the VLA token sequence, processes
|
||||
through transformer encoder layers, and returns the output at the ``e_rl``
|
||||
position as the RL token ``z_rl``.
|
||||
|
||||
Paper Eq. 1: z_rl = g_phi([z_{1:M}, e_rl])_{M+1}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
rl_token_dim: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
ff_dim: int,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.rl_token_dim = rl_token_dim
|
||||
|
||||
self.e_rl = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02)
|
||||
|
||||
if input_dim != rl_token_dim:
|
||||
self.input_proj = nn.Linear(input_dim, rl_token_dim)
|
||||
else:
|
||||
self.input_proj = nn.Identity()
|
||||
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=rl_token_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=ff_dim,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
|
||||
def forward(self, z_vla: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
z_vla: VLA token embeddings, shape ``(B, M, D)``.
|
||||
|
||||
Returns:
|
||||
RL token ``z_rl``, shape ``(B, rl_token_dim)``.
|
||||
"""
|
||||
batch_size = z_vla.shape[0]
|
||||
e_rl = self.e_rl.expand(batch_size, -1, -1)
|
||||
seq = torch.cat([z_vla, e_rl], dim=1) # (B, M+1, D)
|
||||
seq = self.input_proj(seq)
|
||||
out = self.transformer(seq)
|
||||
z_rl = out[:, -1, :] # output at e_rl position
|
||||
return z_rl
|
||||
|
||||
|
||||
# ── RL Token Decoder ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTokenDecoder(nn.Module):
|
||||
"""Autoregressively reconstruct VLA embeddings from z_rl.
|
||||
|
||||
Used only during Stage 1 (offline RL-token training).
|
||||
|
||||
Paper Eq. 2: L_ro = E[sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rl_token_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
ff_dim: int,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
|
||||
if rl_token_dim != output_dim:
|
||||
self.rl_proj = nn.Linear(rl_token_dim, output_dim)
|
||||
else:
|
||||
self.rl_proj = nn.Identity()
|
||||
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=output_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=ff_dim,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
|
||||
self.output_head = nn.Linear(output_dim, output_dim)
|
||||
|
||||
def forward(self, z_rl: Tensor, z_vla_stopped: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
z_rl: RL token, shape ``(B, D_rl)``.
|
||||
z_vla_stopped: Stop-gradient VLA embeddings, shape ``(B, M, D)``.
|
||||
|
||||
Returns:
|
||||
Reconstructed embeddings, shape ``(B, M, D)``.
|
||||
"""
|
||||
seq_len = z_vla_stopped.shape[1]
|
||||
z_rl_proj = self.rl_proj(z_rl).unsqueeze(1)
|
||||
|
||||
target = torch.cat([z_rl_proj, z_vla_stopped[:, :-1, :]], dim=1)
|
||||
|
||||
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=z_rl.device)
|
||||
|
||||
decoded = self.transformer(
|
||||
tgt=target,
|
||||
memory=z_rl_proj,
|
||||
tgt_mask=causal_mask,
|
||||
)
|
||||
return self.output_head(decoded) # (B, M, D)
|
||||
|
||||
|
||||
# ── Actor ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTActor(nn.Module):
|
||||
"""Lightweight actor that refines VLA reference action chunks.
|
||||
|
||||
Paper Eq. 4: pi_theta(a_{1:C} | x, a_tilde_{1:C}) = N(mu_theta(x, a_tilde), sigma^2 I)
|
||||
|
||||
The actor is conditioned on both the RL state and the VLA's proposed action
|
||||
chunk, acting as a "VLA-guided action editor".
|
||||
"""
|
||||
|
||||
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int], std: float = 0.1):
|
||||
super().__init__()
|
||||
input_dim = state_dim + action_chunk_dim
|
||||
self.net = MLP(input_dim, hidden_dims, action_chunk_dim)
|
||||
self.log_std = math.log(std)
|
||||
|
||||
def forward(self, state: Tensor, ref_action_chunk: Tensor) -> Tensor:
|
||||
"""Return the mean action chunk.
|
||||
|
||||
Args:
|
||||
state: RL state ``x = (z_rl, proprioception)``, shape ``(B, state_dim)``.
|
||||
ref_action_chunk: Flattened VLA reference chunk, shape ``(B, C*d)``.
|
||||
|
||||
Returns:
|
||||
Refined action chunk (mean), shape ``(B, C*d)``.
|
||||
"""
|
||||
x = torch.cat([state, ref_action_chunk], dim=-1)
|
||||
return self.net(x)
|
||||
|
||||
def sample(self, state: Tensor, ref_action_chunk: Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""Sample an action and return (action, log_prob)."""
|
||||
mean = self.forward(state, ref_action_chunk)
|
||||
std = math.exp(self.log_std)
|
||||
noise = torch.randn_like(mean) * std
|
||||
action = mean + noise
|
||||
log_prob = -0.5 * (noise / std).pow(2).sum(dim=-1) - mean.shape[-1] * math.log(
|
||||
std * math.sqrt(2 * math.pi)
|
||||
)
|
||||
return action, log_prob
|
||||
|
||||
|
||||
# ── Policy (inference bundle) ────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTPolicy(PreTrainedPolicy):
|
||||
"""RLT policy — bundles the RL-token encoder and actor for inference.
|
||||
|
||||
The frozen VLA backbone is **not** part of this module; it is loaded
|
||||
separately and its embeddings / reference actions are passed in via the
|
||||
observation dict (populated by the actor process or a preprocessor).
|
||||
|
||||
During training, the :class:`RLTAlgorithm` holds the critic, target networks,
|
||||
and optimizers. This class only contains what is needed for ``select_action``.
|
||||
"""
|
||||
|
||||
name = "rlt"
|
||||
config_class = RLTConfig
|
||||
|
||||
def __init__(self, config: RLTConfig, dataset_stats=None):
|
||||
super().__init__(config, dataset_stats)
|
||||
action_dim = config.output_features["action"].shape[0]
|
||||
action_chunk_dim = config.chunk_size * action_dim
|
||||
prop_feature = config.input_features.get("observation.state", None)
|
||||
proprioception_dim = prop_feature.shape[0] if prop_feature is not None else 0
|
||||
|
||||
state_dim = config.rl_token.rl_token_dim + proprioception_dim
|
||||
|
||||
# RL-token encoder (frozen after Stage 1)
|
||||
self.rl_token_encoder = RLTokenEncoder(
|
||||
input_dim=config.rl_token.input_dim,
|
||||
rl_token_dim=config.rl_token.rl_token_dim,
|
||||
num_layers=config.rl_token.num_encoder_layers,
|
||||
num_heads=config.rl_token.num_heads,
|
||||
ff_dim=config.rl_token.ff_dim,
|
||||
dropout=config.rl_token.dropout,
|
||||
)
|
||||
|
||||
# RL-token decoder (used only during Stage 1 training)
|
||||
self.rl_token_decoder = RLTokenDecoder(
|
||||
rl_token_dim=config.rl_token.rl_token_dim,
|
||||
output_dim=config.rl_token.input_dim,
|
||||
num_layers=config.rl_token.num_decoder_layers,
|
||||
num_heads=config.rl_token.num_heads,
|
||||
ff_dim=config.rl_token.ff_dim,
|
||||
dropout=config.rl_token.dropout,
|
||||
)
|
||||
|
||||
# Actor MLP
|
||||
self.actor = RLTActor(
|
||||
state_dim=state_dim,
|
||||
action_chunk_dim=action_chunk_dim,
|
||||
hidden_dims=config.actor.hidden_dims,
|
||||
std=config.actor.std,
|
||||
)
|
||||
|
||||
self._action_dim = action_dim
|
||||
self._action_chunk_dim = action_chunk_dim
|
||||
self._state_dim = state_dim
|
||||
self._proprioception_dim = proprioception_dim
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a refined action chunk given an observation.
|
||||
|
||||
Expects the observation dict to contain:
|
||||
- ``"observation.vla_embeddings"``: VLA internal token embeddings ``(M, D)``
|
||||
- ``"observation.reference_action"``: VLA reference chunk ``(C*d,)``
|
||||
- ``"observation.state"`` (optional): proprioceptive state ``(P,)``
|
||||
|
||||
Returns:
|
||||
Action chunk tensor of shape ``(C*d,)``.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
vla_emb = batch["observation.vla_embeddings"]
|
||||
if vla_emb.dim() == 2:
|
||||
vla_emb = vla_emb.unsqueeze(0)
|
||||
|
||||
z_rl = self.rl_token_encoder(vla_emb) # (1, D_rl)
|
||||
|
||||
parts = [z_rl]
|
||||
if "observation.state" in batch and self._proprioception_dim > 0:
|
||||
prop = batch["observation.state"]
|
||||
if prop.dim() == 1:
|
||||
prop = prop.unsqueeze(0)
|
||||
parts.append(prop)
|
||||
|
||||
state = torch.cat(parts, dim=-1)
|
||||
|
||||
ref = batch["observation.reference_action"]
|
||||
if ref.dim() == 1:
|
||||
ref = ref.unsqueeze(0)
|
||||
|
||||
action = self.actor(state, ref)
|
||||
return action.squeeze(0)
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
@@ -23,7 +23,7 @@ Based on:
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.configs.types import RTCAttentionSchedule, RTCTrainingDelayDistribution
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,3 +53,22 @@ class RTCConfig:
|
||||
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
|
||||
if self.debug_maxlen <= 0:
|
||||
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCTrainingConfig:
|
||||
"""Configuration for training-time RTC action prefix conditioning."""
|
||||
|
||||
enabled: bool = False
|
||||
min_delay: int = 0
|
||||
max_delay: int = 0
|
||||
delay_distribution: RTCTrainingDelayDistribution = RTCTrainingDelayDistribution.UNIFORM
|
||||
exp_decay: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.min_delay < 0:
|
||||
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
|
||||
if self.max_delay < self.min_delay:
|
||||
raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
|
||||
if self.exp_decay <= 0:
|
||||
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import RTCTrainingDelayDistribution
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
|
||||
|
||||
|
||||
def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.device) -> torch.Tensor:
|
||||
if cfg.max_delay == cfg.min_delay:
|
||||
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
|
||||
|
||||
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
|
||||
return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
|
||||
|
||||
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
|
||||
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
|
||||
probs = weights / weights.sum()
|
||||
samples = torch.multinomial(probs, batch_size, replacement=True)
|
||||
return delay_values[samples]
|
||||
|
||||
|
||||
def apply_rtc_training_time(
|
||||
time: torch.Tensor, delay: torch.Tensor, seq_len: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
device = time.device
|
||||
delay = torch.clamp(delay, max=seq_len)
|
||||
prefix_mask = torch.arange(seq_len, device=device)[None, :] < delay[:, None]
|
||||
time_tokens = time[:, None].expand(-1, seq_len)
|
||||
time_tokens = time_tokens.masked_fill(prefix_mask, 0.0)
|
||||
postfix_mask = ~prefix_mask
|
||||
return time_tokens, postfix_mask
|
||||
|
||||
|
||||
def masked_mean(
|
||||
losses: torch.Tensor, mask: torch.Tensor | None, reduce_dims: tuple[int, ...], eps: float = 1e-8
|
||||
) -> torch.Tensor:
|
||||
if mask is None:
|
||||
return losses.mean(dim=reduce_dims)
|
||||
|
||||
mask = mask.to(dtype=losses.dtype)
|
||||
while mask.dim() < losses.dim():
|
||||
mask = mask.unsqueeze(-1)
|
||||
masked = losses * mask
|
||||
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
|
||||
return masked.sum(dim=reduce_dims) / denom
|
||||
|
||||
|
||||
def apply_training_time_rtc_inference(
|
||||
x_t: torch.Tensor,
|
||||
time: float,
|
||||
inference_delay: int | None,
|
||||
prev_chunk_left_over: torch.Tensor | None,
|
||||
chunk_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Apply training-time RTC conditioning during inference.
|
||||
|
||||
Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking".
|
||||
|
||||
At each denoising step:
|
||||
1. Replace prefix positions in x_t with ground truth from previous chunk
|
||||
2. Create per-token timesteps with 1.0 for prefix positions
|
||||
|
||||
Args:
|
||||
x_t: Current noisy actions (B, T, D)
|
||||
time: Current flow matching timestep (scalar)
|
||||
inference_delay: Number of prefix actions to condition on
|
||||
prev_chunk_left_over: Previous chunk's leftover actions (B, T, D)
|
||||
chunk_size: Total chunk size T
|
||||
|
||||
Returns:
|
||||
x_t_conditioned: x_t with prefix replaced by previous actions
|
||||
time_per_token: Per-token timesteps (B, T) with 1.0 for prefix
|
||||
"""
|
||||
batch_size = x_t.shape[0]
|
||||
device = x_t.device
|
||||
|
||||
if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
|
||||
time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
|
||||
return x_t, time_scalar
|
||||
|
||||
delay = min(inference_delay, chunk_size)
|
||||
prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay
|
||||
|
||||
x_t_conditioned = torch.where(
|
||||
prefix_mask[:, :, None].expand_as(x_t),
|
||||
prev_chunk_left_over[:, :chunk_size, :],
|
||||
x_t,
|
||||
)
|
||||
|
||||
time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
|
||||
time_per_token = time_per_token.masked_fill(prefix_mask, 1.0)
|
||||
|
||||
return x_t_conditioned, time_per_token
|
||||
@@ -15,11 +15,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from typing import Literal
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
@@ -47,13 +52,20 @@ class SACPolicy(
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self.encoder = SACObservationEncoder(config)
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_discrete_critic()
|
||||
self._init_temperature()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
"actor": [self.actor.parameters()],
|
||||
"actor": [
|
||||
p
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["discrete_critic"] = self.discrete_critic.parameters()
|
||||
@@ -71,9 +83,10 @@ class SACPolicy(
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
|
||||
observations_features = None
|
||||
if self.encoder.has_images:
|
||||
observations_features = self.encoder.get_cached_image_features(batch)
|
||||
if self.shared_encoder and self.actor.encoder.has_images:
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
@@ -84,35 +97,372 @@ class SACPolicy(
|
||||
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Actor forward pass."""
|
||||
observations = batch.get("state", batch)
|
||||
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
|
||||
actions, log_probs, means = self.actor(observations, observation_features)
|
||||
return {"action": actions, "log_prob": log_probs, "action_mean": means}
|
||||
"""Compute the loss for the given model
|
||||
|
||||
def _init_actor(self, continuous_action_dim: int) -> None:
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder,
|
||||
network=MLP(input_dim=self.encoder.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=False,
|
||||
**asdict(self.config.policy_kwargs),
|
||||
Args:
|
||||
batch: Dictionary containing:
|
||||
- action: Action tensor
|
||||
- reward: Reward tensor
|
||||
- state: Observations tensor dict
|
||||
- next_state: Next observations tensor dict
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
|
||||
loss_critic = self.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
loss_discrete_critic = self.compute_loss_discrete_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
return {"loss_discrete_critic": loss_discrete_critic}
|
||||
if model == "actor":
|
||||
return {
|
||||
"loss_actor": self.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
if model == "temperature":
|
||||
return {
|
||||
"loss_temperature": self.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_param, param in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
def _init_discrete_critic(self) -> None:
|
||||
if self.config.num_discrete_actions is None:
|
||||
self.discrete_critic = None
|
||||
return
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_discrete_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features=None,
|
||||
next_observation_features=None,
|
||||
complementary_info=None,
|
||||
):
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self.discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self.discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self.discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self,
|
||||
observations,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critics()
|
||||
|
||||
def _init_discrete_critics(self):
|
||||
"""Build discrete discrete critic ensemble and target networks."""
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder,
|
||||
input_dim=self.encoder.output_dim,
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
# NOTE: The actor select only the continuous action part
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=self.shared_encoder,
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
|
||||
@@ -27,18 +27,18 @@ Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 5
|
||||
|
||||
@@ -714,12 +714,12 @@ Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 10
|
||||
""",
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -103,8 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
# Real-Time Chunking (RTC) configurations
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
@@ -40,7 +40,7 @@ and an action expert.
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
@@ -15,5 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_wall_x import WallXConfig
|
||||
from .modeling_wall_x import WallXPolicy
|
||||
from .processor_wall_x import make_wall_x_pre_post_processors
|
||||
|
||||
__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"]
|
||||
|
||||
@@ -28,7 +28,14 @@ from .core import (
|
||||
RobotObservation,
|
||||
TransitionKey,
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .delta_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
MapTensorToDeltaActionDictStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .factory import (
|
||||
make_default_processors,
|
||||
@@ -44,7 +51,6 @@ from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
@@ -88,7 +94,6 @@ __all__ = [
|
||||
"DoneProcessorStep",
|
||||
"EnvAction",
|
||||
"EnvTransition",
|
||||
"GymHILAdapterProcessorStep",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"IdentityProcessorStep",
|
||||
@@ -99,6 +104,8 @@ __all__ = [
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"DeltaActionsProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
@@ -128,6 +135,8 @@ __all__ = [
|
||||
"transition_to_batch",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessorStep",
|
||||
"to_absolute_actions",
|
||||
"to_delta_actions",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -14,12 +14,54 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .core import PolicyAction, RobotAction
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
|
||||
|
||||
def to_delta_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||
"""Convert absolute actions to delta: delta = action - state (for masked dims).
|
||||
|
||||
Args:
|
||||
actions: (B, T, action_dim) or (B, action_dim).
|
||||
state: (B, state_dim). Broadcast across time dimension.
|
||||
mask: Which dims to convert. Can be shorter than action_dim.
|
||||
"""
|
||||
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
|
||||
dims = mask_t.shape[0]
|
||||
state_offset = state[..., :dims] * mask_t
|
||||
if actions.ndim == 3:
|
||||
state_offset = state_offset.unsqueeze(-2)
|
||||
actions = actions.clone()
|
||||
actions[..., :dims] -= state_offset
|
||||
return actions
|
||||
|
||||
|
||||
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||
"""Convert delta actions back to absolute: absolute = delta + state (for masked dims).
|
||||
|
||||
Args:
|
||||
actions: (B, T, action_dim) or (B, action_dim).
|
||||
state: (B, state_dim). Broadcast across time dimension.
|
||||
mask: Which dims to convert. Can be shorter than action_dim.
|
||||
"""
|
||||
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
|
||||
dims = mask_t.shape[0]
|
||||
state_offset = state[..., :dims] * mask_t
|
||||
if actions.ndim == 3:
|
||||
state_offset = state_offset.unsqueeze(-2)
|
||||
actions = actions.clone()
|
||||
actions[..., :dims] += state_offset
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||
@@ -141,3 +183,126 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@dataclass
|
||||
class DeltaActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to delta actions (action -= state) for masked dimensions.
|
||||
|
||||
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
|
||||
trains on relative offsets instead of absolute positions.
|
||||
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
|
||||
the conversion during postprocessing.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the delta conversion.
|
||||
exclude_joints: Joint names to keep absolute (not converted to delta).
|
||||
action_names: Action dimension names from dataset metadata, used to build
|
||||
the mask from exclude_joints. If None, all dims are converted.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
exclude_joints: list[str] = field(default_factory=list)
|
||||
action_names: list[str] | None = None
|
||||
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def _build_mask(self, action_dim: int) -> list[bool]:
|
||||
if not self.exclude_joints or self.action_names is None:
|
||||
return [True] * action_dim
|
||||
|
||||
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||
if not exclude_tokens:
|
||||
return [True] * action_dim
|
||||
|
||||
mask = []
|
||||
for name in self.action_names[:action_dim]:
|
||||
action_name = str(name).lower()
|
||||
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
|
||||
mask.append(not is_excluded)
|
||||
|
||||
if len(mask) < action_dim:
|
||||
mask.extend([True] * (action_dim - len(mask)))
|
||||
|
||||
return mask
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||
state = observation.get(OBS_STATE) if observation else None
|
||||
|
||||
# Always cache state for the paired AbsoluteActionsProcessorStep
|
||||
if state is not None:
|
||||
self._last_state = state
|
||||
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None or state is None:
|
||||
return new_transition
|
||||
|
||||
mask = self._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled, "exclude_joints": self.exclude_joints}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("absolute_actions_processor")
|
||||
@dataclass
|
||||
class AbsoluteActionsProcessorStep(ProcessorStep):
|
||||
"""Converts delta actions back to absolute actions (action += state) for all dimensions.
|
||||
|
||||
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
|
||||
predicted deltas are converted back to absolute positions for execution.
|
||||
Reads the cached state from its paired DeltaActionsProcessorStep.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the absolute conversion.
|
||||
delta_step: Reference to the paired DeltaActionsProcessorStep that caches state.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
delta_step: DeltaActionsProcessorStep | None = field(default=None, repr=False)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
if self.delta_step is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires a paired DeltaActionsProcessorStep "
|
||||
"but delta_step is None. Ensure delta_step is set when constructing the postprocessor."
|
||||
)
|
||||
|
||||
if self.delta_step._last_state is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires state from DeltaActionsProcessorStep "
|
||||
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
|
||||
)
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return new_transition
|
||||
|
||||
mask = self.delta_step._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_absolute_actions(
|
||||
action, self.delta_step._last_state, mask
|
||||
)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -20,7 +20,6 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .converters import to_tensor
|
||||
from .core import EnvAction, EnvTransition, PolicyAction
|
||||
from .hil_processor import TELEOP_ACTION_KEY
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@@ -90,13 +89,6 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
new_transition[TransitionKey.ACTION] = torch_action
|
||||
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if TELEOP_ACTION_KEY in complementary_data:
|
||||
teleop_action = complementary_data[TELEOP_ACTION_KEY]
|
||||
if isinstance(teleop_action, EnvAction):
|
||||
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -312,37 +312,6 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
|
||||
class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
|
||||
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
transition[TransitionKey.INFO] = info
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
|
||||
@@ -131,15 +131,6 @@ class _NormalizationMixin:
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
@@ -158,7 +149,6 @@ class _NormalizationMixin:
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -208,7 +198,6 @@ class _NormalizationMixin:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
self._reshape_visual_stats()
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
@@ -219,7 +208,6 @@ class _NormalizationMixin:
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
@@ -343,11 +331,9 @@ class _NormalizationMixin:
|
||||
)
|
||||
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
# Avoid division by zero by adding a small epsilon.
|
||||
denom = std + self.eps
|
||||
if inverse:
|
||||
return tensor * std + mean
|
||||
return (tensor - mean) / denom
|
||||
return tensor * (std + 1e-6) + mean
|
||||
return (tensor - mean) / (std + 1e-6)
|
||||
|
||||
if norm_mode == NormalizationMode.MIN_MAX:
|
||||
min_val = stats.get("min", None)
|
||||
@@ -379,11 +365,7 @@ class _NormalizationMixin:
|
||||
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
|
||||
)
|
||||
|
||||
denom = q99 - q01
|
||||
# Avoid division by zero by adding epsilon when quantiles are identical
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
||||
)
|
||||
denom = q99 - q01 + 1e-6
|
||||
if inverse:
|
||||
return (tensor + 1.0) * denom / 2.0 + q01
|
||||
return 2.0 * (tensor - q01) / denom - 1.0
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
+19
-9
@@ -61,7 +61,7 @@ from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
@@ -248,16 +248,16 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
policy = make_policy(
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# TODO: Re-enable processor pipeline once refactoring is validated against main
|
||||
# preprocessor, postprocessor = None, None
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
@@ -288,6 +288,7 @@ def act_with_policy(
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
@@ -648,12 +649,12 @@ def interactions_stream(
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
|
||||
"""Load the latest policy weights from the learner."""
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dicts = bytes_to_state_dict(bytes_state_dict)
|
||||
|
||||
# TODO: check encoder parameter synchronization possible issues:
|
||||
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
|
||||
# instead of the updated encoder params from critic (which is optimized separately)
|
||||
@@ -663,9 +664,18 @@ def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue,
|
||||
# - Send critic's encoder state when shared_encoder=True
|
||||
# - Skip encoder params entirely when freeze_vision_encoder=True
|
||||
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
|
||||
|
||||
# Load actor state dict
|
||||
state_dicts = move_state_dict_to_device(state_dicts, device=device)
|
||||
policy.load_state_dict(state_dicts)
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
policy.actor.load_state_dict(actor_state_dict)
|
||||
|
||||
# Load discrete critic if present
|
||||
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
)
|
||||
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
|
||||
|
||||
# Utilities functions
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import (
|
||||
RLAlgorithm,
|
||||
RLAlgorithmConfig,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.algorithms.rlt import RLTAlgorithm, RLTAlgorithmConfig
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
|
||||
|
||||
def make_algorithm(
|
||||
policy: torch.nn.Module,
|
||||
policy_cfg,
|
||||
*,
|
||||
algorithm_name: str,
|
||||
) -> RLAlgorithm:
|
||||
"""Construct an :class:`RLAlgorithm` from a policy and its config.
|
||||
|
||||
Algorithm selection is explicit via ``algorithm_name`` (from
|
||||
``cfg.algorithm``).
|
||||
|
||||
This is fully registry-driven — adding a new algorithm only requires
|
||||
registering an ``RLAlgorithmConfig`` subclass; no changes here.
|
||||
|
||||
The returned algorithm has **no optimizers** yet. On the learner side,
|
||||
call ``algorithm.make_optimizers()`` afterwards to create them. On the
|
||||
actor side (inference-only), leave them empty.
|
||||
|
||||
Args:
|
||||
policy: Instantiated policy (e.g. ``SACPolicy``).
|
||||
policy_cfg: The policy's ``PreTrainedConfig`` with the hyper-parameters
|
||||
expected by the algorithm config's ``from_policy_config`` class-method.
|
||||
algorithm_name: Algorithm registry key to instantiate.
|
||||
"""
|
||||
known = RLAlgorithmConfig.get_known_choices()
|
||||
if algorithm_name not in known:
|
||||
raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}")
|
||||
|
||||
config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name)
|
||||
algo_config = config_cls.from_policy_config(policy_cfg)
|
||||
return algo_config.build_algorithm(policy)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RLAlgorithm",
|
||||
"RLAlgorithmConfig",
|
||||
"TrainingStats",
|
||||
"SACAlgorithm",
|
||||
"SACAlgorithmConfig",
|
||||
"RLTAlgorithm",
|
||||
"RLTAlgorithmConfig",
|
||||
"make_algorithm",
|
||||
]
|
||||
@@ -1,183 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Base classes for RL algorithms.
|
||||
|
||||
Defines the abstract interface that every algorithm must implement, a registry
|
||||
for algorithm configs, and a dataclass for training statistics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.data_sources.data_mixer import DataMixer
|
||||
|
||||
BatchType = dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingStats:
|
||||
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
|
||||
|
||||
# Generic containers for all algorithms
|
||||
losses: dict[str, float] = field(default_factory=dict)
|
||||
grad_norms: dict[str, float] = field(default_factory=dict)
|
||||
extra: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_log_dict(self) -> dict[str, float]:
|
||||
"""Flatten all stats into a single dict for logging."""
|
||||
|
||||
d: dict[str, float] = {}
|
||||
for name, val in self.losses.items():
|
||||
d[name] = val
|
||||
for name, val in self.grad_norms.items():
|
||||
d[f"{name}_grad_norm"] = val
|
||||
for name, val in self.extra.items():
|
||||
d[name] = val
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLAlgorithmConfig(draccus.ChoiceRegistry):
|
||||
"""Registry for algorithm configs."""
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm:
|
||||
"""Construct the :class:`RLAlgorithm` for this config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()")
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
|
||||
"""Build an algorithm config from a policy config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
|
||||
|
||||
|
||||
class RLAlgorithm(abc.ABC):
|
||||
"""Base for all RL algorithms."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One complete training step.
|
||||
|
||||
The algorithm calls ``next(batch_iterator)`` as many times as it
|
||||
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
|
||||
The iterator is owned by the trainer; the algorithm just consumes
|
||||
from it.
|
||||
"""
|
||||
...
|
||||
|
||||
def supports_offline_phase(self) -> bool:
|
||||
"""Whether this algorithm has an offline pretraining phase.
|
||||
|
||||
Algorithms like RLT (RL-token training) or ConRFT (Cal-QL pretraining)
|
||||
return ``True`` here. The learner checks this before the main online
|
||||
loop and routes to :meth:`offline_update` accordingly.
|
||||
"""
|
||||
return False
|
||||
|
||||
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One offline training step (called before any online collection).
|
||||
|
||||
Only called when :meth:`supports_offline_phase` returns ``True``.
|
||||
Uses the same iterator protocol as :meth:`update`.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} does not implement offline_update(). "
|
||||
"Either override this method or return False from supports_offline_phase()."
|
||||
)
|
||||
|
||||
def transition_to_online(self) -> None: # noqa: B027
|
||||
"""Called once when switching from offline to online phase.
|
||||
|
||||
Use this to freeze modules trained offline, rebuild optimizers for the
|
||||
online phase, reset step counters, etc.
|
||||
|
||||
Default is a no-op; subclasses override when they have an offline phase.
|
||||
"""
|
||||
|
||||
def configure_data_iterator(
|
||||
self,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
) -> Iterator[BatchType]:
|
||||
"""Create the data iterator this algorithm needs.
|
||||
|
||||
The default implementation uses the standard ``data_mixer.get_iterator()``.
|
||||
Algorithms that need specialised sampling should override this method.
|
||||
"""
|
||||
return data_mixer.get_iterator(
|
||||
batch_size=batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
def make_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Create, store, and return the optimizers needed for training.
|
||||
|
||||
Called on the **learner** side after construction. Subclasses must
|
||||
override this with algorithm-specific optimizer setup.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Return optimizers for checkpointing / external scheduling."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def optimization_step(self) -> int:
|
||||
"""Current learner optimization step.
|
||||
|
||||
Part of the stable contract for checkpoint/resume. Algorithms can
|
||||
either use this default storage or override for custom behavior.
|
||||
"""
|
||||
return getattr(self, "_optimization_step", 0)
|
||||
|
||||
@optimization_step.setter
|
||||
def optimization_step(self, value: int) -> None:
|
||||
self._optimization_step = int(value)
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Policy state-dict to push to actors."""
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner (inverse of ``get_weights``)."""
|
||||
|
||||
@torch.no_grad()
|
||||
def get_observation_features(
|
||||
self, observations: Tensor, next_observations: Tensor
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Pre-compute observation features (e.g. frozen encoder cache).
|
||||
|
||||
Returns ``(None, None)`` when caching is not applicable.
|
||||
"""
|
||||
return None, None
|
||||
@@ -1,18 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
|
||||
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
|
||||
|
||||
__all__ = ["RLTAlgorithm", "RLTAlgorithmConfig"]
|
||||
@@ -1,83 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT algorithm configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
|
||||
|
||||
|
||||
@RLAlgorithmConfig.register_subclass("rlt")
|
||||
@dataclass
|
||||
class RLTAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""RLT-specific hyper-parameters that control the update loop."""
|
||||
|
||||
# ── Action chunks ──
|
||||
chunk_size: int = 10
|
||||
chunk_stride: int = 2
|
||||
|
||||
# ── Update cadence ──
|
||||
utd_ratio: int = 5
|
||||
policy_update_freq: int = 2
|
||||
clip_grad_norm: float = 10.0
|
||||
|
||||
# ── Learning rates ──
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
rl_token_lr: float = 1e-4
|
||||
|
||||
# ── TD learning ──
|
||||
discount: float = 0.99
|
||||
tau: float = 0.005
|
||||
num_critics: int = 2
|
||||
|
||||
# ── Policy constraint (paper Eq. 5) ──
|
||||
bc_reg_coeff: float = 0.1
|
||||
ref_dropout: float = 0.5
|
||||
|
||||
# ── Offline RL-token training ──
|
||||
vla_finetune_weight: float = 0.0
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg) -> RLTAlgorithmConfig:
|
||||
"""Build from an existing ``RLTConfig`` (cfg.policy)."""
|
||||
return cls(
|
||||
chunk_size=policy_cfg.chunk_size,
|
||||
chunk_stride=policy_cfg.chunk_stride,
|
||||
utd_ratio=policy_cfg.utd_ratio,
|
||||
policy_update_freq=policy_cfg.policy_update_freq,
|
||||
clip_grad_norm=policy_cfg.clip_grad_norm,
|
||||
actor_lr=policy_cfg.actor_lr,
|
||||
critic_lr=policy_cfg.critic_lr,
|
||||
rl_token_lr=policy_cfg.rl_token_lr,
|
||||
discount=policy_cfg.discount,
|
||||
tau=policy_cfg.tau,
|
||||
num_critics=policy_cfg.num_critics,
|
||||
bc_reg_coeff=policy_cfg.bc_reg_coeff,
|
||||
ref_dropout=policy_cfg.ref_dropout,
|
||||
vla_finetune_weight=policy_cfg.vla_finetune_weight,
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> RLTAlgorithm:
|
||||
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
|
||||
|
||||
return RLTAlgorithm(policy=policy, config=self)
|
||||
@@ -1,319 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT (RL Token) algorithm.
|
||||
|
||||
Implements the two-stage training from "RL Token: Bootstrapping Online RL
|
||||
with Vision-Language-Action Models" (Xu et al., Physical Intelligence, 2026).
|
||||
|
||||
Stage 1 (offline): Train RL-token encoder/decoder via reconstruction loss.
|
||||
Stage 2 (online): Train actor-critic with chunked TD, BC regularization,
|
||||
reference-action pass-through, and reference-action dropout.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.rlt.modeling_rlt import MLP, RLTPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.rl.algorithms.base import (
|
||||
BatchType,
|
||||
RLAlgorithm,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
class RLTCritic(nn.Module):
|
||||
"""Q-function over (state, action_chunk) pairs.
|
||||
|
||||
Paper Eq. 3: Q_psi(x, a_{1:C})
|
||||
|
||||
Training-only component — lives on the algorithm side, not in the policy.
|
||||
"""
|
||||
|
||||
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int]):
|
||||
super().__init__()
|
||||
self.net = MLP(state_dim + action_chunk_dim, hidden_dims, output_dim=1)
|
||||
|
||||
def forward(self, state: Tensor, action_chunk: Tensor) -> Tensor:
|
||||
x = torch.cat([state, action_chunk], dim=-1)
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class RLTAlgorithm(RLAlgorithm):
|
||||
"""RL Token: lightweight actor-critic on frozen VLA features.
|
||||
|
||||
Owns the ``RLTPolicy`` (RL-token encoder/decoder + actor), a critic
|
||||
ensemble, and target networks. All VLA-specific logic (embedding
|
||||
extraction, reference actions) lives in ``_prepare_forward_batch``.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: RLTPolicy, config: RLTAlgorithmConfig):
|
||||
self.policy = policy
|
||||
self.config = config
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
self._device = get_device_from_parameters(self.policy)
|
||||
self._is_online = False
|
||||
|
||||
self._init_critics()
|
||||
self._move_to_device()
|
||||
|
||||
# ── Initialization ───────────────────────────────────────────────
|
||||
|
||||
def _init_critics(self) -> None:
|
||||
state_dim = self.policy._state_dim
|
||||
action_chunk_dim = self.policy._action_chunk_dim
|
||||
hidden_dims = self.policy.config.critic.hidden_dims
|
||||
|
||||
self.critics = torch.nn.ModuleList(
|
||||
[RLTCritic(state_dim, action_chunk_dim, hidden_dims) for _ in range(self.config.num_critics)]
|
||||
)
|
||||
self.critic_targets = torch.nn.ModuleList([copy.deepcopy(c) for c in self.critics])
|
||||
for ct in self.critic_targets:
|
||||
ct.requires_grad_(False)
|
||||
|
||||
def _move_to_device(self) -> None:
|
||||
self.critics.to(self._device)
|
||||
self.critic_targets.to(self._device)
|
||||
|
||||
# ── Offline phase (Stage 1): RL-token training ───────────────────
|
||||
|
||||
def supports_offline_phase(self) -> bool:
|
||||
return True
|
||||
|
||||
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""Train RL-token encoder/decoder on demonstration data.
|
||||
|
||||
Paper Eq. 2: L_ro = E[ sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2 ]
|
||||
"""
|
||||
batch = next(batch_iterator)
|
||||
|
||||
vla_embeddings = batch["state"]["observation.vla_embeddings"].to(self._device)
|
||||
z_vla = vla_embeddings.detach() # stop-gradient on VLA embeddings
|
||||
|
||||
z_rl = self.policy.rl_token_encoder(z_vla)
|
||||
z_reconstructed = self.policy.rl_token_decoder(z_rl, z_vla)
|
||||
|
||||
loss_ro = F.mse_loss(z_reconstructed, z_vla)
|
||||
|
||||
self.optimizers["rl_token"].zero_grad()
|
||||
loss_ro.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
list(self.policy.rl_token_encoder.parameters()) + list(self.policy.rl_token_decoder.parameters()),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
)
|
||||
self.optimizers["rl_token"].step()
|
||||
|
||||
self._optimization_step += 1
|
||||
return TrainingStats(losses={"loss_rl_token": loss_ro.item()})
|
||||
|
||||
def transition_to_online(self) -> None:
|
||||
"""Freeze RL-token modules; rebuild optimizers for actor-critic only."""
|
||||
self.policy.rl_token_encoder.requires_grad_(False)
|
||||
self.policy.rl_token_decoder.requires_grad_(False)
|
||||
self._is_online = True
|
||||
|
||||
self.optimizers = {
|
||||
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
|
||||
}
|
||||
self._optimization_step = 0
|
||||
|
||||
# ── Online phase (Stage 2): Actor-Critic ─────────────────────────
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One full RLT update step with UTD critic warm-up.
|
||||
|
||||
Pulls ``utd_ratio`` batches. First ``utd_ratio - 1`` are critic-only;
|
||||
the last batch also updates the actor (every ``policy_update_freq`` steps).
|
||||
"""
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch)
|
||||
self._critic_step(fb)
|
||||
self._update_target_networks()
|
||||
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch)
|
||||
critic_loss = self._critic_step(fb)
|
||||
|
||||
stats = TrainingStats(losses={"loss_critic": critic_loss})
|
||||
|
||||
if self._optimization_step % self.config.policy_update_freq == 0:
|
||||
actor_loss, bc_loss, q_val = self._actor_step(fb)
|
||||
stats.losses["loss_actor"] = actor_loss
|
||||
stats.extra["bc_loss"] = bc_loss
|
||||
stats.extra["q_value_mean"] = q_val
|
||||
|
||||
self._update_target_networks()
|
||||
self._optimization_step += 1
|
||||
return stats
|
||||
|
||||
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
|
||||
"""Convert a replay batch into algorithm-ready tensors.
|
||||
|
||||
Extracts RL-token from VLA embeddings, builds RL state, reads
|
||||
reference action from complementary_info.
|
||||
"""
|
||||
obs = batch["state"]
|
||||
next_obs = batch["next_state"]
|
||||
device = self._device
|
||||
|
||||
vla_emb = obs["observation.vla_embeddings"].to(device)
|
||||
next_vla_emb = next_obs["observation.vla_embeddings"].to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
z_rl = self.policy.rl_token_encoder(vla_emb)
|
||||
z_rl_next = self.policy.rl_token_encoder(next_vla_emb)
|
||||
|
||||
parts = [z_rl]
|
||||
next_parts = [z_rl_next]
|
||||
if "observation.state" in obs and self.policy._proprioception_dim > 0:
|
||||
prop = obs["observation.state"].to(device)
|
||||
next_prop = next_obs["observation.state"].to(device)
|
||||
parts.append(prop)
|
||||
next_parts.append(next_prop)
|
||||
|
||||
state = torch.cat(parts, dim=-1)
|
||||
next_state = torch.cat(next_parts, dim=-1)
|
||||
|
||||
action = batch[ACTION].to(device)
|
||||
reward = batch["reward"].to(device)
|
||||
done = batch["done"].to(device)
|
||||
|
||||
ref_action = None
|
||||
comp_info = batch.get("complementary_info")
|
||||
if comp_info is not None and "reference_action" in comp_info:
|
||||
ref_action = comp_info["reference_action"].to(device)
|
||||
|
||||
return {
|
||||
"state": state,
|
||||
"next_state": next_state,
|
||||
"action": action,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"reference_action": ref_action,
|
||||
}
|
||||
|
||||
def _critic_step(self, fb: dict[str, Any]) -> float:
|
||||
"""Paper Eq. 3: chunked TD with clipped double-Q target."""
|
||||
state = fb["state"]
|
||||
next_state = fb["next_state"]
|
||||
action = fb["action"]
|
||||
reward = fb["reward"]
|
||||
done = fb["done"]
|
||||
|
||||
with torch.no_grad():
|
||||
ref = fb.get("reference_action")
|
||||
if ref is None:
|
||||
ref = torch.zeros_like(action)
|
||||
next_action = self.policy.actor(next_state, ref)
|
||||
|
||||
target_qs = [ct(next_state, next_action) for ct in self.critic_targets]
|
||||
min_target_q = torch.min(torch.cat(target_qs, dim=-1), dim=-1, keepdim=True).values
|
||||
|
||||
discount_chunk = self.config.discount**self.config.chunk_size
|
||||
td_target = reward.unsqueeze(-1) + (1 - done.unsqueeze(-1)) * discount_chunk * min_target_q
|
||||
|
||||
q_preds = [c(state, action) for c in self.critics]
|
||||
loss = sum(F.mse_loss(q, td_target) for q in q_preds)
|
||||
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.critics.parameters(), max_norm=self.config.clip_grad_norm)
|
||||
self.optimizers["critic"].step()
|
||||
return loss.item()
|
||||
|
||||
def _actor_step(self, fb: dict[str, Any]) -> tuple[float, float, float]:
|
||||
"""Paper Eq. 5: maximize Q while staying near VLA reference.
|
||||
|
||||
L_pi(theta) = E[ -Q(x, a) + beta * ||a - a_tilde||^2 ]
|
||||
With reference-action dropout applied to the actor's ref input.
|
||||
"""
|
||||
state = fb["state"]
|
||||
ref = fb.get("reference_action")
|
||||
if ref is None:
|
||||
ref = torch.zeros(state.shape[0], self.policy._action_chunk_dim, device=self._device)
|
||||
|
||||
# Reference-action dropout (paper Section IV-B)
|
||||
mask = (torch.rand(ref.shape[0], 1, device=self._device) > self.config.ref_dropout).float()
|
||||
ref_input = ref * mask
|
||||
|
||||
action = self.policy.actor(state, ref_input)
|
||||
|
||||
q_value = self.critics[0](state, action)
|
||||
|
||||
bc_loss = F.mse_loss(action, ref)
|
||||
|
||||
loss = -q_value.mean() + self.config.bc_reg_coeff * bc_loss
|
||||
|
||||
self.optimizers["actor"].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), max_norm=self.config.clip_grad_norm)
|
||||
self.optimizers["actor"].step()
|
||||
|
||||
return loss.item(), bc_loss.item(), q_value.mean().item()
|
||||
|
||||
def _update_target_networks(self) -> None:
|
||||
tau = self.config.tau
|
||||
for critic, target in zip(self.critics, self.critic_targets, strict=True):
|
||||
for p, tp in zip(critic.parameters(), target.parameters(), strict=True):
|
||||
tp.data.copy_(tau * p.data + (1 - tau) * tp.data)
|
||||
|
||||
# ── Optimizer management ─────────────────────────────────────────
|
||||
|
||||
def make_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Create optimizers. Initially for RL-token (Stage 1)."""
|
||||
self.optimizers = {
|
||||
"rl_token": torch.optim.Adam(
|
||||
list(self.policy.rl_token_encoder.parameters())
|
||||
+ list(self.policy.rl_token_decoder.parameters()),
|
||||
lr=self.config.rl_token_lr,
|
||||
),
|
||||
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
|
||||
}
|
||||
return self.optimizers
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
return self.optimizers
|
||||
|
||||
# ── Weight sync ──────────────────────────────────────────────────
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Push actor + RL-token encoder to actors (small footprint)."""
|
||||
weights = {
|
||||
"actor": self.policy.actor.state_dict(),
|
||||
"rl_token_encoder": self.policy.rl_token_encoder.state_dict(),
|
||||
}
|
||||
return {k: {kk: vv.cpu() for kk, vv in v.items()} for k, v in weights.items()}
|
||||
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
if "actor" in weights:
|
||||
self.policy.actor.load_state_dict({k: v.to(device) for k, v in weights["actor"].items()})
|
||||
if "rl_token_encoder" in weights:
|
||||
self.policy.rl_token_encoder.load_state_dict(
|
||||
{k: v.to(device) for k, v in weights["rl_token_encoder"].items()}
|
||||
)
|
||||
@@ -1,81 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SAC algorithm configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig
|
||||
from lerobot.rl.algorithms.base import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
|
||||
@RLAlgorithmConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""SAC-specific hyper-parameters that control the update loop."""
|
||||
|
||||
utd_ratio: int = 1
|
||||
policy_update_freq: int = 1
|
||||
clip_grad_norm: float = 40.0
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
critic_target_update_weight: float = 0.005
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
num_discrete_actions: int | None = None
|
||||
shared_encoder: bool = True
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
use_torch_compile: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
|
||||
"""Build from an existing ``SACConfig`` (cfg.policy) for backwards compat."""
|
||||
return cls(
|
||||
utd_ratio=policy_cfg.utd_ratio,
|
||||
policy_update_freq=policy_cfg.policy_update_freq,
|
||||
clip_grad_norm=policy_cfg.grad_clip_norm,
|
||||
actor_lr=policy_cfg.actor_lr,
|
||||
critic_lr=policy_cfg.critic_lr,
|
||||
temperature_lr=policy_cfg.temperature_lr,
|
||||
discount=policy_cfg.discount,
|
||||
temperature_init=policy_cfg.temperature_init,
|
||||
target_entropy=policy_cfg.target_entropy,
|
||||
use_backup_entropy=policy_cfg.use_backup_entropy,
|
||||
critic_target_update_weight=policy_cfg.critic_target_update_weight,
|
||||
num_critics=policy_cfg.num_critics,
|
||||
num_subsample_critics=policy_cfg.num_subsample_critics,
|
||||
num_discrete_actions=policy_cfg.num_discrete_actions,
|
||||
shared_encoder=policy_cfg.shared_encoder,
|
||||
critic_network_kwargs=policy_cfg.critic_network_kwargs,
|
||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||
use_torch_compile=policy_cfg.use_torch_compile,
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
return SACAlgorithm(policy=policy, config=self)
|
||||
@@ -1,409 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SAC (Soft Actor-Critic) algorithm.
|
||||
|
||||
This module encapsulates all SAC-specific training logic (critic, actor,
|
||||
temperature, and discrete-critic updates) behind the ``RLAlgorithm`` interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.sac.modeling_sac import (
|
||||
DISCRETE_DIMENSION_INDEX,
|
||||
CriticEnsemble,
|
||||
CriticHead,
|
||||
DiscreteCritic,
|
||||
SACObservationEncoder,
|
||||
SACPolicy,
|
||||
)
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.rl.algorithms.base import (
|
||||
BatchType,
|
||||
RLAlgorithm,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.transition import move_state_dict_to_device
|
||||
|
||||
|
||||
class SACAlgorithm(RLAlgorithm):
|
||||
"""Soft Actor-Critic with optional discrete-critic head.
|
||||
|
||||
Owns the ``SACPolicy`` and its optimizers. All loss methods call
|
||||
``self.policy(batch_dict)`` rather than reaching into ``self.policy.actor``
|
||||
directly, so any policy that returns ``{"action", "log_prob"}`` from its
|
||||
``forward()`` is compatible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: SACPolicy,
|
||||
config: SACAlgorithmConfig,
|
||||
):
|
||||
self.policy = policy
|
||||
self.config = config
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
|
||||
self._device = get_device_from_parameters(self.policy)
|
||||
self._init_critic_encoder()
|
||||
self._init_critics()
|
||||
self._init_temperature()
|
||||
self._move_to_device()
|
||||
|
||||
def _init_critic_encoder(self) -> None:
|
||||
"""Build or share the encoder used by critics."""
|
||||
if self.config.shared_encoder:
|
||||
self.critic_encoder = self.policy.encoder
|
||||
self.policy.actor.encoder_is_shared = True
|
||||
else:
|
||||
self.critic_encoder = SACObservationEncoder(self.policy.config)
|
||||
|
||||
def _init_critics(self) -> None:
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
input_dim = self.critic_encoder.output_dim + action_dim
|
||||
|
||||
heads = [
|
||||
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.critic_encoder, ensemble=heads)
|
||||
|
||||
target_heads = [
|
||||
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.critic_encoder, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critic_target()
|
||||
|
||||
def _init_discrete_critic_target(self) -> None:
|
||||
"""Build only the target discrete critic."""
|
||||
input_dim = self.critic_encoder.output_dim
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.critic_encoder,
|
||||
input_dim=input_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
# TODO: (kmeftah) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha) and default target entropy."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _move_to_device(self) -> None:
|
||||
"""Move algorithm-owned modules to the policy device."""
|
||||
self.critic_ensemble.to(self._device)
|
||||
self.critic_target.to(self._device)
|
||||
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
|
||||
if hasattr(self, "discrete_critic_target"):
|
||||
self.discrete_critic_target.to(self._device)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""Run one full SAC update with UTD critic warm-up.
|
||||
|
||||
Pulls ``utd_ratio`` batches from ``batch_iterator``. The first
|
||||
``utd_ratio - 1`` batches are used for critic-only warm-up steps;
|
||||
the last batch drives the full update (critic + actor + temperature).
|
||||
"""
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
forward_batch = self._prepare_forward_batch(batch)
|
||||
|
||||
loss_critic = self._compute_loss_critic(forward_batch)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.critic_ensemble.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.discrete_critic.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["discrete_critic"].step()
|
||||
self._update_target_networks()
|
||||
|
||||
batch = next(batch_iterator)
|
||||
forward_batch = self._prepare_forward_batch(batch)
|
||||
|
||||
loss_critic = self._compute_loss_critic(forward_batch)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.critic_ensemble.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
critic_loss_val = loss_critic.item()
|
||||
stats = TrainingStats(
|
||||
losses={"loss_critic": critic_loss_val},
|
||||
grad_norms={"critic": critic_grad_norm},
|
||||
)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete.backward()
|
||||
dc_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.discrete_critic.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["discrete_critic"].step()
|
||||
stats.losses["loss_discrete_critic"] = loss_discrete.item()
|
||||
stats.grad_norms["discrete_critic"] = dc_grad
|
||||
|
||||
if self._optimization_step % self.config.policy_update_freq == 0:
|
||||
for _ in range(self.config.policy_update_freq):
|
||||
actor_loss = self._compute_loss_actor(forward_batch)
|
||||
self.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
actor_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.actor.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["actor"].step()
|
||||
|
||||
temp_loss = self._compute_loss_temperature(forward_batch)
|
||||
self.optimizers["temperature"].zero_grad()
|
||||
temp_loss.backward()
|
||||
temp_grad = torch.nn.utils.clip_grad_norm_(
|
||||
[self.log_alpha],
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["temperature"].step()
|
||||
|
||||
stats.losses["loss_actor"] = actor_loss.item()
|
||||
stats.losses["loss_temperature"] = temp_loss.item()
|
||||
stats.grad_norms["actor"] = actor_grad
|
||||
stats.grad_norms["temperature"] = temp_grad
|
||||
stats.extra["temperature"] = self.temperature
|
||||
|
||||
self._update_target_networks()
|
||||
|
||||
self._optimization_step += 1
|
||||
return stats
|
||||
|
||||
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
next_obs_features = batch.get("next_observation_feature")
|
||||
|
||||
with torch.no_grad():
|
||||
next_output = self.policy({"state": next_observations, "observation_feature": next_obs_features})
|
||||
next_actions = next_output["action"]
|
||||
next_log_probs = next_output["log_prob"]
|
||||
|
||||
q_targets = self.critic_target(next_observations, next_actions, next_obs_features)
|
||||
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
min_q, _ = q_targets.min(dim=0)
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
actions = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
|
||||
q_preds = self.critic_ensemble(observations, actions, obs_features)
|
||||
|
||||
td_target_dup = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
critics_loss = (F.mse_loss(input=q_preds, target=td_target_dup, reduction="none").mean(dim=1)).sum()
|
||||
return critics_loss
|
||||
|
||||
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
next_obs_features = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete).long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
next_discrete_qs = self.policy.discrete_critic(next_observations, next_obs_features)
|
||||
best_next_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
target_next_qs = self.discrete_critic_target(next_observations, next_obs_features)
|
||||
target_next_q = torch.gather(target_next_qs, dim=1, index=best_next_action).squeeze(-1)
|
||||
|
||||
rewards_disc = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_disc = rewards + discrete_penalties
|
||||
target_q = rewards_disc + (1 - done) * self.config.discount * target_next_q
|
||||
|
||||
predicted_qs = self.policy.discrete_critic(observations, obs_features)
|
||||
predicted_q = torch.gather(predicted_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
return F.mse_loss(input=predicted_q, target=target_q)
|
||||
|
||||
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
|
||||
output = self.policy({"state": observations, "observation_feature": obs_features})
|
||||
actions_pi = output["action"]
|
||||
log_probs = output["log_prob"]
|
||||
|
||||
q_preds = self.critic_ensemble(observations, actions_pi, obs_features)
|
||||
min_q = q_preds.min(dim=0)[0]
|
||||
|
||||
return ((self.temperature * log_probs) - min_q).mean()
|
||||
|
||||
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
|
||||
with torch.no_grad():
|
||||
output = self.policy({"state": observations, "observation_feature": obs_features})
|
||||
log_probs = output["log_prob"]
|
||||
|
||||
return (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
|
||||
def _update_target_networks(self) -> None:
|
||||
tau = self.config.critic_target_update_weight
|
||||
for target_p, p in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
|
||||
):
|
||||
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_p, p in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.policy.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
|
||||
|
||||
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
|
||||
"""Build the dict expected by loss computation from a sampled batch."""
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
|
||||
observation_features, next_observation_features = self.get_observation_features(
|
||||
observations, next_observations
|
||||
)
|
||||
forward_batch: dict[str, Any] = {
|
||||
ACTION: batch[ACTION],
|
||||
"reward": batch["reward"],
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": batch["done"],
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
if "complementary_info" in batch:
|
||||
forward_batch["complementary_info"] = batch["complementary_info"]
|
||||
return forward_batch
|
||||
|
||||
def make_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Create Adam optimizers for the SAC components and store them."""
|
||||
actor_params = [
|
||||
p
|
||||
for n, p in self.policy.actor.named_parameters()
|
||||
if not self.config.shared_encoder or not n.startswith("encoder")
|
||||
]
|
||||
self.optimizers = {
|
||||
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
|
||||
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
|
||||
)
|
||||
return self.optimizers
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
return self.optimizers
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Policy state-dict to push to actors (includes actor + discrete critic)."""
|
||||
return move_state_dict_to_device(self.policy.state_dict(), device="cpu")
|
||||
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner."""
|
||||
state = move_state_dict_to_device(weights, device=device)
|
||||
self.policy.load_state_dict(state)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_observation_features(
|
||||
self, observations: Tensor, next_observations: Tensor
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
if not self.config.shared_encoder:
|
||||
return None, None
|
||||
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
if not self.policy.encoder.has_images:
|
||||
return None, None
|
||||
observation_features = self.policy.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = self.policy.encoder.get_cached_image_features(next_observations)
|
||||
return observation_features, next_observation_features
|
||||
@@ -1,94 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
|
||||
BatchType = dict[str, Any]
|
||||
|
||||
|
||||
class DataMixer(abc.ABC):
|
||||
"""Abstract interface for all data mixing strategies.
|
||||
|
||||
Subclasses must implement ``sample(batch_size)`` and may override
|
||||
``get_iterator`` for specialised iteration.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
"""Draw one batch of ``batch_size`` transitions."""
|
||||
...
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Infinite iterator that yields batches.
|
||||
|
||||
The default implementation repeatedly calls ``self.sample()``.
|
||||
Subclasses with underlying buffer iterators (async prefetch)
|
||||
should override this for better throughput.
|
||||
"""
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
|
||||
|
||||
class OnlineOfflineMixer(DataMixer):
|
||||
"""Mixes transitions from an online and an optional offline replay buffer.
|
||||
|
||||
When both buffers are present, each batch is constructed by sampling
|
||||
``ceil(batch_size * online_ratio)`` from the online buffer and the
|
||||
remainder from the offline buffer, then concatenating.
|
||||
|
||||
This mixer assumes both online and offline buffers are present.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer | None = None,
|
||||
online_ratio: float = 1.0,
|
||||
):
|
||||
if not 0.0 <= online_ratio <= 1.0:
|
||||
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
|
||||
self.online_buffer = online_buffer
|
||||
self.offline_buffer = offline_buffer
|
||||
self.online_ratio = online_ratio
|
||||
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
if self.offline_buffer is None:
|
||||
return self.online_buffer.sample(batch_size)
|
||||
|
||||
n_online = max(1, int(batch_size * self.online_ratio))
|
||||
n_offline = batch_size - n_online
|
||||
|
||||
online_batch = self.online_buffer.sample(n_online)
|
||||
offline_batch = self.offline_buffer.sample(n_offline)
|
||||
return concatenate_batch_transitions(online_batch, offline_batch)
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Yield batches from online/offline mixed sampling."""
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
@@ -36,7 +36,6 @@ from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
@@ -380,7 +379,6 @@ def make_processors(
|
||||
]
|
||||
|
||||
env_pipeline_steps = [
|
||||
GymHILAdapterProcessorStep(),
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -610,14 +608,7 @@ def control_loop(
|
||||
|
||||
dataset = None
|
||||
if cfg.mode == "record":
|
||||
if teleop_device:
|
||||
action_features = teleop_device.action_features
|
||||
else:
|
||||
action_features = {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": ["delta_x", "delta_y", "delta_z", "gripper"],
|
||||
}
|
||||
action_features = teleop_device.action_features
|
||||
features = {
|
||||
ACTION: action_features,
|
||||
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
|
||||
@@ -665,7 +656,7 @@ def control_loop(
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
|
||||
# Use the new step function
|
||||
transition = step_env_and_process_transition(
|
||||
@@ -734,8 +725,6 @@ def control_loop(
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Finalizing dataset before pushing to hub")
|
||||
dataset.finalize()
|
||||
logging.info("Pushing dataset to hub")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
+284
-92
@@ -65,11 +65,9 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
from lerobot.rl.wandb_utils import WandBLogger
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
@@ -95,7 +93,7 @@ from lerobot.utils.train_utils import (
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.utils.transition import move_transition_to_device
|
||||
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
@@ -266,8 +264,8 @@ def add_actor_information_and_train(
|
||||
- Transfers transitions from the actor to the replay buffer.
|
||||
- Logs received interaction messages.
|
||||
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
|
||||
- Delegates training updates to an ``RLAlgorithm`` (currently ``SACAlgorithm``).
|
||||
- Periodically pushes updated weights to actors.
|
||||
- Samples batches from the replay buffer and performs multiple critic updates.
|
||||
- Periodically updates the actor, critic, and temperature optimizers.
|
||||
- Logs training statistics, including loss values and optimization frequency.
|
||||
|
||||
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
@@ -286,15 +284,17 @@ def add_actor_information_and_train(
|
||||
# of 7%
|
||||
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
|
||||
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
|
||||
clip_grad_norm_value = cfg.policy.grad_clip_norm
|
||||
online_step_before_learning = cfg.policy.online_step_before_learning
|
||||
utd_ratio = cfg.policy.utd_ratio
|
||||
fps = cfg.env.fps
|
||||
log_freq = cfg.log_freq
|
||||
save_freq = cfg.save_freq
|
||||
policy_update_freq = cfg.policy.policy_update_freq
|
||||
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
|
||||
saving_checkpoint = cfg.save_checkpoint
|
||||
online_steps = cfg.policy.online_steps
|
||||
async_prefetch = cfg.async_prefetch
|
||||
queue_size = cfg.queue_size
|
||||
async_prefetch = cfg.policy.async_prefetch
|
||||
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
@@ -306,7 +306,7 @@ def add_actor_information_and_train(
|
||||
|
||||
logging.info("Initializing policy")
|
||||
|
||||
policy = make_policy(
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
@@ -315,24 +315,19 @@ def add_actor_information_and_train(
|
||||
|
||||
policy.train()
|
||||
|
||||
algorithm = make_algorithm(
|
||||
policy=policy,
|
||||
policy_cfg=cfg.policy,
|
||||
algorithm_name=cfg.algorithm,
|
||||
)
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
|
||||
# TODO: Re-enable processor pipeline once refactoring is validated against main
|
||||
preprocessor, postprocessor = None, None
|
||||
|
||||
# Push initial policy weights to actors (same path as periodic push)
|
||||
state_bytes = state_to_bytes(algorithm.get_weights())
|
||||
parameters_queue.put(state_bytes)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
log_training_info(cfg=cfg, policy=policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
|
||||
total_batch_size = cfg.batch_size
|
||||
batch_size = cfg.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset is not None:
|
||||
@@ -341,70 +336,20 @@ def add_actor_information_and_train(
|
||||
device=device,
|
||||
storage_device=storage_device,
|
||||
)
|
||||
|
||||
# DataMixer: online-only or online/offline 50-50 mix
|
||||
data_mixer = OnlineOfflineMixer(
|
||||
online_buffer=replay_buffer,
|
||||
offline_buffer=offline_replay_buffer,
|
||||
online_ratio=cfg.online_ratio,
|
||||
)
|
||||
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
|
||||
trainer = RLTrainer(
|
||||
algorithm=algorithm,
|
||||
data_mixer=data_mixer,
|
||||
batch_size=total_batch_size,
|
||||
preprocessor=preprocessor,
|
||||
action_dim=cfg.policy.output_features["action"].shape[0],
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
optimizers = algorithm.get_optimizers()
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
logging.info("Starting learner thread")
|
||||
interaction_message = None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
algorithm.optimization_step = optimization_step
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
|
||||
dataset_repo_id = None
|
||||
if cfg.dataset is not None:
|
||||
dataset_repo_id = cfg.dataset.repo_id
|
||||
|
||||
# ── Offline phase (e.g. RLT RL-token training, ConRFT Cal-QL pretraining) ──
|
||||
offline_steps = getattr(cfg.policy, "offline_steps", 0)
|
||||
if algorithm.supports_offline_phase() and offline_steps > 0 and offline_replay_buffer is not None:
|
||||
logging.info(f"[LEARNER] Starting offline phase ({offline_steps} steps)")
|
||||
offline_mixer = OnlineOfflineMixer(
|
||||
online_buffer=offline_replay_buffer,
|
||||
offline_buffer=None,
|
||||
online_ratio=1.0,
|
||||
)
|
||||
offline_iterator = algorithm.configure_data_iterator(
|
||||
data_mixer=offline_mixer,
|
||||
batch_size=total_batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
for step in range(offline_steps):
|
||||
if shutdown_event is not None and shutdown_event.is_set():
|
||||
logging.info("[LEARNER] Shutdown during offline phase. Exiting...")
|
||||
return
|
||||
|
||||
stats = algorithm.offline_update(offline_iterator)
|
||||
|
||||
if step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Offline step {step}/{offline_steps}: {stats.to_log_dict()}")
|
||||
if wandb_logger:
|
||||
log_dict = stats.to_log_dict()
|
||||
log_dict["offline_step"] = step
|
||||
wandb_logger.log_dict(d=log_dict, mode="train", custom_step_key="offline_step")
|
||||
|
||||
algorithm.transition_to_online()
|
||||
optimizers = algorithm.get_optimizers()
|
||||
logging.info("[LEARNER] Offline phase complete, transitioned to online")
|
||||
# Initialize iterators
|
||||
online_iterator = None
|
||||
offline_iterator = None
|
||||
|
||||
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
|
||||
while True:
|
||||
@@ -435,22 +380,180 @@ def add_actor_information_and_train(
|
||||
if len(replay_buffer) < online_step_before_learning:
|
||||
continue
|
||||
|
||||
time_for_one_optimization_step = time.time()
|
||||
if online_iterator is None:
|
||||
online_iterator = replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
|
||||
stats = trainer.training_step()
|
||||
if offline_replay_buffer is not None and offline_iterator is None:
|
||||
offline_iterator = offline_replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(utd_ratio - 1):
|
||||
# Sample from the iterators
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
# Main critic optimization
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
|
||||
# Sample for the last update in the UTD ratio
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Initialize training info dictionary
|
||||
training_infos = {
|
||||
"loss_critic": loss_critic.item(),
|
||||
"critic_grad_norm": critic_grad_norm,
|
||||
}
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Add discrete critic info to training info
|
||||
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
|
||||
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
|
||||
|
||||
# Actor and temperature optimization (at specified frequency)
|
||||
if optimization_step % policy_update_freq == 0:
|
||||
for _ in range(policy_update_freq):
|
||||
# Actor optimization
|
||||
actor_output = policy.forward(forward_batch, model="actor")
|
||||
loss_actor = actor_output["loss_actor"]
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["actor"].step()
|
||||
|
||||
# Add actor info to training info
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["actor_grad_norm"] = actor_grad_norm
|
||||
|
||||
# Temperature optimization
|
||||
temperature_output = policy.forward(forward_batch, model="temperature")
|
||||
loss_temperature = temperature_output["loss_temperature"]
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
# Add temperature info to training info
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
state_dicts = algorithm.get_weights()
|
||||
state_bytes = state_to_bytes(state_dicts)
|
||||
parameters_queue.put(state_bytes)
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
training_infos = stats.to_log_dict()
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
|
||||
# Log training metrics at specified intervals
|
||||
optimization_step = algorithm.optimization_step
|
||||
if optimization_step % log_freq == 0:
|
||||
training_infos["replay_buffer_size"] = len(replay_buffer)
|
||||
if offline_replay_buffer is not None:
|
||||
@@ -478,6 +581,7 @@ def add_actor_information_and_train(
|
||||
custom_step_key="Optimization step",
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
@@ -494,8 +598,6 @@ def add_actor_information_and_train(
|
||||
offline_replay_buffer=offline_replay_buffer,
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
fps=fps,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -580,8 +682,6 @@ def save_training_checkpoint(
|
||||
offline_replay_buffer: ReplayBuffer | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
fps: int = 30,
|
||||
preprocessor=None,
|
||||
postprocessor=None,
|
||||
) -> None:
|
||||
"""
|
||||
Save training checkpoint and associated data.
|
||||
@@ -605,8 +705,6 @@ def save_training_checkpoint(
|
||||
offline_replay_buffer: Optional offline replay buffer to save
|
||||
dataset_repo_id: Repository ID for dataset
|
||||
fps: Frames per second for dataset
|
||||
preprocessor: Optional preprocessor pipeline to save
|
||||
postprocessor: Optional postprocessor pipeline to save
|
||||
"""
|
||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||
_num_digits = max(6, len(str(online_steps)))
|
||||
@@ -623,8 +721,6 @@ def save_training_checkpoint(
|
||||
policy=policy,
|
||||
optimizer=optimizers,
|
||||
scheduler=None,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
# Save interaction step manually
|
||||
@@ -662,6 +758,58 @@ def save_training_checkpoint(
|
||||
logging.info("Resume training")
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
NOTE:
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
|
||||
A tuple containing:
|
||||
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
|
||||
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_discrete_critic = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizers["discrete_critic"] = optimizer_discrete_critic
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
# Training setup functions
|
||||
|
||||
|
||||
@@ -866,6 +1014,33 @@ def initialize_offline_replay_buffer(
|
||||
# Utilities/Helpers functions
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Get observation features from the policy encoder. It act as cache for the observation features.
|
||||
when the encoder is frozen, the observation features are not updated.
|
||||
We can save compute by caching the observation features.
|
||||
|
||||
Args:
|
||||
policy: The policy model
|
||||
observations: The current observations
|
||||
next_observations: The next observations
|
||||
|
||||
Returns:
|
||||
tuple: observation_features, next_observation_features
|
||||
"""
|
||||
|
||||
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.learner == "threads"
|
||||
|
||||
@@ -916,6 +1091,23 @@ def check_nan_in_transition(
|
||||
return nan_detected
|
||||
|
||||
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
logging.debug("[LEARNER] Pushing actor policy to the queue")
|
||||
|
||||
# Create a dictionary to hold all the state dicts
|
||||
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
|
||||
|
||||
# Add discrete critic if it exists
|
||||
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
policy.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
logging.debug("[LEARNER] Including discrete critic in state dict push")
|
||||
|
||||
state_bytes = state_to_bytes(state_dicts)
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
def process_interaction_message(
|
||||
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
|
||||
):
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import (
|
||||
BatchType,
|
||||
RLAlgorithm,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.data_sources.data_mixer import DataMixer
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def preprocess_rl_batch(preprocessor: Any, batch: BatchType, *, action_dim: int | None = None) -> BatchType:
|
||||
"""Apply a policy preprocessor to an RL batch."""
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
actions = batch[ACTION]
|
||||
|
||||
extra_action = None
|
||||
if action_dim is not None and actions.shape[-1] > action_dim:
|
||||
extra_action = actions[..., action_dim:]
|
||||
actions = actions[..., :action_dim]
|
||||
|
||||
obs_action = {**observations, ACTION: actions}
|
||||
obs_action = preprocessor(obs_action)
|
||||
batch["state"] = {k: v for k, v in obs_action.items() if k.startswith("observation.")}
|
||||
batch[ACTION] = obs_action[ACTION]
|
||||
|
||||
if extra_action is not None:
|
||||
batch[ACTION] = torch.cat([batch[ACTION], extra_action], dim=-1)
|
||||
|
||||
next_obs = {**next_observations}
|
||||
next_obs = preprocessor(next_obs)
|
||||
batch["next_state"] = {k: v for k, v in next_obs.items() if k.startswith("observation.")}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class _PreprocessedIterator:
|
||||
"""Iterator wrapper that preprocesses each sampled RL batch."""
|
||||
|
||||
__slots__ = ("_raw", "_preprocessor", "_action_dim")
|
||||
|
||||
def __init__(
|
||||
self, raw_iterator: Iterator[BatchType], preprocessor: Any, action_dim: int | None = None
|
||||
) -> None:
|
||||
self._raw = raw_iterator
|
||||
self._preprocessor = preprocessor
|
||||
self._action_dim = action_dim
|
||||
|
||||
def __iter__(self) -> _PreprocessedIterator:
|
||||
return self
|
||||
|
||||
def __next__(self) -> BatchType:
|
||||
batch = next(self._raw)
|
||||
return preprocess_rl_batch(self._preprocessor, batch, action_dim=self._action_dim)
|
||||
|
||||
|
||||
class RLTrainer:
|
||||
"""Unified training step orchestrator.
|
||||
|
||||
Holds the algorithm, a DataMixer, and an optional preprocessor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: RLAlgorithm,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
preprocessor: Any | None = None,
|
||||
action_dim: int | None = None,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
self.algorithm = algorithm
|
||||
self.data_mixer = data_mixer
|
||||
self.batch_size = batch_size
|
||||
self._preprocessor = preprocessor
|
||||
self._action_dim = action_dim
|
||||
self.async_prefetch = async_prefetch
|
||||
self.queue_size = queue_size
|
||||
|
||||
self._iterator: Iterator[BatchType] | None = None
|
||||
|
||||
self.algorithm.make_optimizers()
|
||||
|
||||
def _build_data_iterator(self) -> Iterator[BatchType]:
|
||||
"""Create a fresh algorithm-configured iterator (optionally preprocessed)."""
|
||||
raw = self.algorithm.configure_data_iterator(
|
||||
data_mixer=self.data_mixer,
|
||||
batch_size=self.batch_size,
|
||||
async_prefetch=self.async_prefetch,
|
||||
queue_size=self.queue_size,
|
||||
)
|
||||
if self._preprocessor is not None:
|
||||
return _PreprocessedIterator(raw, self._preprocessor, self._action_dim)
|
||||
return raw
|
||||
|
||||
def reset_data_iterator(self) -> None:
|
||||
"""Discard the current iterator so it will be rebuilt lazily next step."""
|
||||
self._iterator = None
|
||||
|
||||
def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
|
||||
"""Swap the active data mixer, optionally resetting the iterator."""
|
||||
self.data_mixer = data_mixer
|
||||
if reset:
|
||||
self.reset_data_iterator()
|
||||
|
||||
def training_step(self) -> TrainingStats:
|
||||
"""Run one training step (algorithm-agnostic)."""
|
||||
if self._iterator is None:
|
||||
self._iterator = self._build_data_iterator()
|
||||
return self.algorithm.update(self._iterator)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user