mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Compare commits
143 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7788db7838 | |||
| d883c78a94 | |||
| de42da8225 | |||
| d0d714be47 | |||
| 7d9b469eee | |||
| 6db39cad58 | |||
| af0676f99e | |||
| b9df1a4ac5 | |||
| 5361346bec | |||
| f0b969ae48 | |||
| a9d54cbddb | |||
| c5a029a28a | |||
| c8163662ad | |||
| 376cc772ff | |||
| d1eefd4e97 | |||
| 7a03223693 | |||
| f840d2e006 | |||
| e94844fa59 | |||
| 990f8e9cc9 | |||
| 6ce2a00135 | |||
| bf90efa7e1 | |||
| 5b4ac3068e | |||
| dbe3406a69 | |||
| 1785767e61 | |||
| afd833f49e | |||
| 2234b851c0 | |||
| e4a214d890 | |||
| e8438aac59 | |||
| 8fe977118b | |||
| d09b2a28af | |||
| f2530570e0 | |||
| 8567ab60d8 | |||
| 9784123463 | |||
| 4c2add41d7 | |||
| a19d7fb6bf | |||
| 565c992589 | |||
| 96cc634a66 | |||
| b044f3104b | |||
| 384ec52ec7 | |||
| 8d1434c069 | |||
| f613a37cd2 | |||
| 494aa576b2 | |||
| 514625a7f6 | |||
| 9f7bfeb419 | |||
| aa40c8c813 | |||
| d36bdac114 | |||
| ff1666b216 | |||
| c57d3a9688 | |||
| 9ae11a087d | |||
| 21e63b505f | |||
| e9e7eb827a | |||
| ac323b0113 | |||
| b028907d21 | |||
| 2eafcc7ca1 | |||
| b3b57a8288 | |||
| eaaf1c1766 | |||
| 3bc3bf0391 | |||
| 8c5fe10d6c | |||
| 8178a06b90 | |||
| 9ea8bd029c | |||
| bd5c264c49 | |||
| 5c628f1700 | |||
| d602e8169c | |||
| 49baccdccb | |||
| 9beafe0c19 | |||
| 27c9db60a6 | |||
| fda5fb5e94 | |||
| 5f5438d6fa | |||
| 2b779cd6c6 | |||
| 3886af42a5 | |||
| 38f7229078 | |||
| 504421949c | |||
| 28b9efc04f | |||
| abba423e28 | |||
| 47a81c4150 | |||
| 6a3d57031a | |||
| d74494d92b | |||
| 1ba896598e | |||
| 61e55830da | |||
| b7522da85d | |||
| 98dc053e6d | |||
| bbff93d20d | |||
| 32c1649085 | |||
| eb564f8ddb | |||
| a2958a8e0c | |||
| 8f1679f309 | |||
| b1473f11c8 | |||
| 7b556079d8 | |||
| e91a773b93 | |||
| a9bd67eae9 | |||
| 4a4ac759ec | |||
| 7dd8e015f8 | |||
| af2960c33e | |||
| a36e4619ad | |||
| b397a757bb | |||
| 92adf2218f | |||
| f3614dd812 | |||
| b23b7a5bd7 | |||
| 882c80d446 | |||
| 6ff5f318b2 | |||
| 2eae751977 | |||
| 894878039d | |||
| ab72471dda | |||
| 23849e0cb8 | |||
| cb18fc07ef | |||
| 440e22c184 | |||
| 28b69bf8ba | |||
| b997fdde96 | |||
| 6f975cf576 | |||
| 2688731064 | |||
| 61b0eeae4b | |||
| fe20437b62 | |||
| ff861ba869 | |||
| 577cd10974 | |||
| 4be3942cbc | |||
| fd5afdfbf0 | |||
| 8d2c66abd2 | |||
| b0923ab74b | |||
| 7f70b78f32 | |||
| afad90ffaa | |||
| f5091448a8 | |||
| cc46497f4c | |||
| 5d25f5bd40 | |||
| ce83752f16 | |||
| 4ed6cf159d | |||
| 7626d26e6a | |||
| 14a59f576b | |||
| eb3649292b | |||
| ac0993c2e3 | |||
| 55198de096 | |||
| c20bf75ba0 | |||
| a25480d363 | |||
| 4c19a71d7c | |||
| 0878c6880f | |||
| d2684d41cd | |||
| 4e76c1f88c | |||
| 11e6bd762a | |||
| ce3b9f627e | |||
| 3bf0c19be7 | |||
| ad4f510262 | |||
| 9124b36b0a | |||
| 4bc356b7f3 | |||
| 21a961ecbb |
@@ -30,7 +30,7 @@ pytest -sx tests/test_stuff.py::test_something
|
||||
```
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --some.option=true
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
|
||||
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
|
||||
|
||||
@@ -29,8 +29,8 @@ on:
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-gpu:latest
|
||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-cpu:latest
|
||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
||||
|
||||
# Ensures that only the latest commit is built, canceling older runs.
|
||||
concurrency:
|
||||
|
||||
@@ -44,7 +44,7 @@ test-end-to-end:
|
||||
${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-eval
|
||||
|
||||
test-act-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--policy.dim_model=64 \
|
||||
--policy.n_action_steps=20 \
|
||||
@@ -68,12 +68,12 @@ test-act-ete-train:
|
||||
--output_dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-train-resume:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
|
||||
test-act-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
@@ -82,7 +82,7 @@ test-act-ete-eval:
|
||||
--eval.batch_size=1
|
||||
|
||||
test-diffusion-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=diffusion \
|
||||
--policy.down_dims='[64,128,256]' \
|
||||
--policy.diffusion_step_embed_dim=32 \
|
||||
@@ -106,7 +106,7 @@ test-diffusion-ete-train:
|
||||
--output_dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=pusht \
|
||||
@@ -115,7 +115,7 @@ test-diffusion-ete-eval:
|
||||
--eval.batch_size=1
|
||||
|
||||
test-tdmpc-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=tdmpc \
|
||||
--policy.device=$(DEVICE) \
|
||||
--policy.push_to_hub=false \
|
||||
@@ -137,7 +137,7 @@ test-tdmpc-ete-train:
|
||||
--output_dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=xarm \
|
||||
@@ -148,7 +148,7 @@ test-tdmpc-ete-eval:
|
||||
|
||||
|
||||
test-smolvla-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.n_action_steps=20 \
|
||||
--policy.chunk_size=20 \
|
||||
@@ -171,7 +171,7 @@ test-smolvla-ete-train:
|
||||
--output_dir=tests/outputs/smolvla/
|
||||
|
||||
test-smolvla-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/smolvla/checkpoints/000004/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/nighty.yml?query=branch%3Amain)
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml?query=branch%3Amain)
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
@@ -276,7 +276,7 @@ Check out [example 2](https://github.com/huggingface/lerobot/blob/main/examples/
|
||||
We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/diffusion_pusht \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
@@ -288,10 +288,10 @@ python -m lerobot.scripts.eval \
|
||||
Note: After training your own policy, you can re-evaluate the checkpoints with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
lerobot-eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
See `python -m lerobot.scripts.eval --help` for more instructions.
|
||||
See `lerobot-eval --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
@@ -303,7 +303,7 @@ A link to the wandb logs for the run will also show up in yellow in your termina
|
||||
|
||||
\<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/wandb.png" alt="WandB logs example"\>
|
||||
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python -m lerobot.scripts.eval --help` for more instructions.
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `lerobot-eval --help` for more instructions.
|
||||
|
||||
#### Reproduce state-of-the-art (SOTA)
|
||||
|
||||
@@ -311,7 +311,7 @@ We provide some pretrained policies on our [hub page](https://huggingface.co/ler
|
||||
You can reproduce their training by loading the config from their run. Simply running:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
|
||||
lerobot-train --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
@@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
# Install system dependencies and uv (as root)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \
|
||||
build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
|
||||
@@ -19,21 +19,16 @@
|
||||
title: Train RL in Simulation
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: libero
|
||||
title: Using LIBERO
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: smolvla
|
||||
title: Finetune SmolVLA
|
||||
title: "Policies"
|
||||
|
||||
- sections:
|
||||
- local: introduction_processors
|
||||
title: Introduction to Robot Processors
|
||||
- local: implement_your_own_processor
|
||||
title: Implement your own processor
|
||||
- local: processors_robots_teleop
|
||||
title: Processors for Robots and Teleoperators
|
||||
title: "Robot Processors"
|
||||
- sections:
|
||||
- local: hope_jr
|
||||
title: Hope Jr
|
||||
- local: so101
|
||||
title: SO-101
|
||||
- local: so100
|
||||
@@ -42,16 +37,14 @@
|
||||
title: Koch v1.1
|
||||
- local: lekiwi
|
||||
title: LeKiwi
|
||||
- local: hope_jr
|
||||
title: Hope Jr
|
||||
- local: reachy2
|
||||
title: Reachy 2
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
title: Phone
|
||||
title: "Teleoperators"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
- local: feetech
|
||||
title: Updating Feetech Firmware
|
||||
title: "Resources"
|
||||
- sections:
|
||||
- local: contributing
|
||||
|
||||
@@ -9,7 +9,7 @@ To instantiate a camera, you need a camera identifier. This identifier might cha
|
||||
To find the camera indices of the cameras plugged into your system, run the following script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_cameras opencv # or realsense for Intel Realsense cameras
|
||||
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
```
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
# Feetech Motor Firmware Update
|
||||
|
||||
This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Windows computer (Feetech software is only available for Windows)
|
||||
- Feetech motor control board
|
||||
- USB cable to connect the control board to your computer
|
||||
- Feetech motors connected to the control board
|
||||
|
||||
## Step 1: Download Feetech Software
|
||||
|
||||
1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html)
|
||||
2. Download the latest version of the Feetech debugging software (FD)
|
||||
3. Install the software on your Windows computer
|
||||
|
||||
## Step 2: Hardware Setup
|
||||
|
||||
1. Connect your Feetech motors to the motor control board
|
||||
2. Connect the motor control board to your Windows computer via USB cable
|
||||
3. Ensure power is supplied to the motors
|
||||
|
||||
## Step 3: Configure Connection
|
||||
|
||||
1. Launch the Feetech debugging software
|
||||
2. Select the correct COM port from the port dropdown menu
|
||||
- If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)"
|
||||
3. Set the appropriate baud rate (typically 1000000 for most Feetech motors)
|
||||
4. Click "Open" to establish communication with the control board
|
||||
|
||||
## Step 4: Scan for Motors
|
||||
|
||||
1. Once connected, click the "Search" button to detect all connected motors
|
||||
2. The software will automatically discover and list all motors on the bus
|
||||
3. Each motor will appear with its ID number
|
||||
|
||||
## Step 5: Update Firmware
|
||||
|
||||
For each motor you want to update:
|
||||
|
||||
1. **Select the motor** from the list by clicking on it
|
||||
2. **Click on Upgrade tab**:
|
||||
3. **Click on Online button**:
|
||||
- If an potential firmware update is found, it will be displayed in the box
|
||||
4. **Click on Upgrade button**:
|
||||
- The update progress will be displayed
|
||||
|
||||
## Step 6: Verify Update
|
||||
|
||||
1. After the update completes, the software should automatically refresh the motor information
|
||||
2. Verify that the firmware version has been updated to the expected version
|
||||
|
||||
## Important Notes
|
||||
|
||||
⚠️ **Warning**: Do not disconnect power or USB during firmware updates, it will potentially brick the motor.
|
||||
|
||||
## Bonus: Motor Debugging on Linux/macOS
|
||||
|
||||
For debugging purposes only, you can use the open-source Feetech Debug Tool:
|
||||
|
||||
- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer)
|
||||
|
||||
### Installation Instructions
|
||||
|
||||
Follow the instructions in the repository to install the tool, for Ubuntu you can directly install it, for MacOS you need to build it from source.
|
||||
|
||||
**Limitations:**
|
||||
|
||||
- This tool is for debugging and parameter adjustment only
|
||||
- Firmware updates must still be done on Windows with official Feetech software
|
||||
@@ -412,7 +412,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
To train the classifier, use the `train.py` script with your configuration:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path path/to/reward_classifier_train_config.json
|
||||
lerobot-train --config_path path/to/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
**Deploying and Testing the Model**
|
||||
@@ -458,7 +458,7 @@ The reward classifier will automatically provide rewards based on the visual inp
|
||||
3. **Train the classifier**:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path src/lerobot/configs/reward_classifier_train_config.json
|
||||
lerobot-train --config_path src/lerobot/configs/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
4. **Test the classifier**:
|
||||
|
||||
+11
-11
@@ -19,7 +19,7 @@ pip install -e ".[hopejr]"
|
||||
Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts.
|
||||
@@ -31,7 +31,7 @@ Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibrati
|
||||
### 1.1 Calibrate Robot Hand
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=blue \
|
||||
@@ -81,7 +81,7 @@ Once you have set the appropriate boundaries for all joints, click "Save" to sav
|
||||
### 1.2 Calibrate Teleoperator Glove
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=homunculus_glove \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=red \
|
||||
@@ -120,7 +120,7 @@ Once calibration is complete, the system will save the calibration to `/Users/yo
|
||||
### 1.3 Calibrate Robot Arm
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=hope_jr_arm \
|
||||
--robot.port=/dev/tty.usbserial-1110 \
|
||||
--robot.id=white
|
||||
@@ -146,7 +146,7 @@ Use the calibration interface to set the range boundaries for each joint. Move e
|
||||
### 1.4 Calibrate Teleoperator Exoskeleton
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=homunculus_arm \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=black
|
||||
@@ -178,7 +178,7 @@ Due to global variable conflicts in the Feetech middleware, teleoperation for ar
|
||||
### Hand
|
||||
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=blue \
|
||||
@@ -194,7 +194,7 @@ python -m lerobot.teleoperate \
|
||||
### Arm
|
||||
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=hope_jr_arm \
|
||||
--robot.port=/dev/tty.usbserial-1110 \
|
||||
--robot.id=white \
|
||||
@@ -214,7 +214,7 @@ Record, Replay and Train with Hope-JR is still experimental.
|
||||
This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings).
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
@@ -236,7 +236,7 @@ python -m lerobot.record \
|
||||
### Replay
|
||||
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
@@ -248,7 +248,7 @@ python -m lerobot.replay \
|
||||
### Train
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
@@ -263,7 +263,7 @@ python -m lerobot.scripts.train \
|
||||
This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino).
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
|
||||
+10
-21
@@ -45,7 +45,7 @@ Note that the `id` associated with a robot is used to store the calibration file
|
||||
<hfoptions id="teleoperate_so101">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -101,7 +101,7 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoptions id="teleoperate_koch_camera">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -174,7 +174,7 @@ Now you can record a dataset. To record 5 episodes and upload your dataset to th
|
||||
<hfoptions id="record">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -376,7 +376,7 @@ You can replay the first episode on your robot with either the command below or
|
||||
<hfoptions id="replay">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -428,10 +428,10 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
@@ -453,7 +453,7 @@ Training should take several hours. You will find checkpoints in `outputs/train/
|
||||
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
```
|
||||
@@ -490,7 +490,7 @@ You can use the `record` script from [`lerobot/record.py`](https://github.com/hu
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
@@ -519,14 +519,11 @@ from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.policies.factory import make_processor
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
@@ -538,7 +535,7 @@ robot_config = SO100FollowerConfig(
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Initialize the policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
policy = ACTPolicy.from_pretrained("<hf_username>/<my_policy_repo_id>")
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
@@ -547,7 +544,7 @@ dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
repo_id="<hf_username>/eval_<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
@@ -562,12 +559,6 @@ _init_rerun(session_name="recording")
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
@@ -577,8 +568,6 @@ for episode_idx in range(NUM_EPISODES):
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
|
||||
@@ -96,10 +96,10 @@ If you uploaded your dataset to the hub you can [visualize your dataset online](
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/il_gym \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/il_sim_test \
|
||||
|
||||
@@ -1,323 +0,0 @@
|
||||
# Implement your own Robot Processor
|
||||
|
||||
In this tutorial, you'll learn how to implement your own Robot Processor.
|
||||
It begins by exploring the need for a custom processor, then uses the Normalization processors as the running example to explain how to implement, configure, and serialize a processor. Finally, it lists all helper processors that ship with LeRobot.
|
||||
|
||||
## Why would you need a custom processor?
|
||||
|
||||
In most cases, when reading raw data from a sensor like the camera and robot motor encoders,
|
||||
you will need to process this data to transform it into a format that is compatible to use with the policies in LeRobot.
|
||||
For example, raw images are encoded with `uint8` and the values are in the range `[0, 255]`.
|
||||
To use these images with the policies, you will need to cast them to `float32` and normalize them to the range `[0, 1]`.
|
||||
|
||||
For example, in LeRobot's `VanillaObservationProcessor`, raw images come from the environment as numpy arrays with `uint8` values in range `[0, 255]` and in channel-last format `(H, W, C)`. The processor transforms them into PyTorch tensors with `float32` values in range `[0, 1]` and channel-first format `(C, H, W)`:
|
||||
|
||||
```python
|
||||
# Input: numpy array with shape (480, 640, 3) and dtype uint8
|
||||
raw_image = env_observation["pixels"] # Values in [0, 255]
|
||||
|
||||
# After processing: torch tensor with shape (1, 3, 480, 640) and dtype float32
|
||||
processed_image = processor(transition)["observation"]["observation.image"] # Values in [0, 1]
|
||||
```
|
||||
|
||||
On the other hand, when a model returns a certain action to be executed on the robot, it is often that one has to post-process this action to make it compatible to run on the robot.
|
||||
For example, the model might return joint positions values that range from `[-1, 1]` and one would need to scale them to the ranges of the minimum and maximum joint angle positions of the robot.
|
||||
|
||||
In LeRobot, this normalization workflow is handled by the `NormalizerProcessor` (for inputs) and the `UnnormalizerProcessor` (for outputs). These processors are heavily used by policies (e.g., Pi0, SmolVLA) and integrate tightly with the `RobotProcessor`'s `get_config`, `state_dict`, and `load_state_dict` APIs.
|
||||
|
||||
For instance, `UnnormalizerProcessor` converts model outputs in `[-1, 1]` back to actual robot joint ranges:
|
||||
|
||||
```python
|
||||
# Input: model action with normalized values in [-1, 1]
|
||||
normalized_action = torch.tensor([-0.5, 0.8, -1.0, 0.2]) # Model output
|
||||
|
||||
# After post-processing: real joint positions in robot's native ranges
|
||||
# Example: joints range from [-180.0, 180.0]
|
||||
real_action = unnormalizer(transition)["action"]
|
||||
# real action after post-processing: [ -90., 144., -180., 36.]
|
||||
```
|
||||
|
||||
The unnormalizer uses the dataset statistics to convert back:
|
||||
|
||||
```python
|
||||
# For MIN_MAX normalization: action = (normalized + 1) * (max - min) / 2 + min
|
||||
real_action = (normalized_action + 1) * (max_val - min_val) / 2 + min_val
|
||||
```
|
||||
|
||||
All these situations point us towards the need for a mechanism to preprocess the data before being passed to the policies and then post-process the action that are returned to be executed on the robot.
|
||||
|
||||
To that end, LeRobot provides a pipeline mechanism to implement a sequence of processing steps for the input data and the output action.
|
||||
|
||||
## How to implement your own processor?
|
||||
|
||||
We'll use the `NormalizerProcessor` as a concrete running example because it is central to most policies and demonstrates configuration and state serialization cleanly.
|
||||
|
||||
Prepare the sequence of processing steps necessary for your problem. A processor step is a class that implements the following methods:
|
||||
|
||||
- `__call__`: implements the processing step for the input transition.
|
||||
- `get_config`: gets the configuration of the processor step.
|
||||
- `state_dict`: gets the state of the processor step.
|
||||
- `load_state_dict`: loads the state of the processor step.
|
||||
- `reset`: resets the state of the processor step.
|
||||
- `feature_contract`: displays the modification to the feature space during the processor step.
|
||||
|
||||
### Implement the `__call__` method
|
||||
|
||||
The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's how the `NormalizerProcessor` conceptually works (simplified):
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
|
||||
@dataclass
|
||||
class NormalizerProcessor:
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
stats: dict[str, dict[str, torch.Tensor]]
|
||||
eps: float = 1e-8
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
normalized_info = {}
|
||||
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
act = transition.get(TransitionKey.ACTION)
|
||||
|
||||
new_obs = self._normalize_observation(obs, normalized_info)
|
||||
new_act = self._normalize_action(act, normalized_info)
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_obs
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
|
||||
# Record what was normalized into complementary_data
|
||||
if normalized_info:
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
comp = dict(comp)
|
||||
comp["normalized_keys"] = normalized_info
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
|
||||
return new_transition
|
||||
```
|
||||
|
||||
See the full implementation in `src/lerobot/processor/normalize_processor.py` for details on mean/std and min/max modes and key selection.
|
||||
|
||||
**Key principles:**
|
||||
|
||||
- Always check if required data exists before processing
|
||||
- Return unchanged transition if no processing is needed
|
||||
- Use `transition.copy()` to avoid side effects
|
||||
- Only modify the specific keys your processor handles
|
||||
|
||||
**Tip**: For observation-only processors, you can inherit from `ObservationProcessor` to avoid writing `__call__` boilerplate. The normalizer is mixed (observations and actions), so it implements `__call__` directly.
|
||||
|
||||
### Configuration and State Management
|
||||
|
||||
Processors support serialization through three methods that separate configuration from tensor state. This is especially important for normalization processors, which carry dataset statistics (tensors) in their state, and hyperparameters in their config:
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
@dataclass
|
||||
class NormalizerProcessor:
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
eps: float = 1e-8
|
||||
_tensor_stats: dict[str, dict[str, torch.Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""JSON-serializable configuration (no tensors)."""
|
||||
return {
|
||||
"eps": self.eps,
|
||||
"features": {k: {"type": v.type.value, "shape": v.shape} for k, v in self.features.items()},
|
||||
"norm_map": {ft.value: nm.value for ft, nm in self.norm_map.items()},
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Tensor state only (e.g., dataset statistics)."""
|
||||
flat: dict[str, torch.Tensor] = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Restore tensor state at runtime."""
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||
```
|
||||
|
||||
**Usage:**
|
||||
|
||||
```python
|
||||
# Save (e.g., inside a policy)
|
||||
config = processor.get_config()
|
||||
tensors = processor.state_dict()
|
||||
|
||||
# Restore (e.g., loading a pretrained policy)
|
||||
new_processor = NormalizerProcessor(**config)
|
||||
new_processor.load_state_dict(tensors)
|
||||
```
|
||||
|
||||
### Transform features
|
||||
|
||||
The `transform_features` method defines how your processor transforms feature names and shapes. This is crucial for policy configuration and debugging.
|
||||
|
||||
Normalization typically preserves the feature keys and shapes, so `NormalizerProcessor.transform_features` returns the input features unchanged. When your processor renames or reshapes, implement this method to reflect the mapping for downstream components. For example, a simple rename processor:
|
||||
|
||||
```python
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# Simple renaming
|
||||
if "pixels" in features:
|
||||
features["observation.image"] = features.pop("pixels")
|
||||
|
||||
# Pattern-based renaming
|
||||
for key in list(features.keys()):
|
||||
if key.startswith("env_state."):
|
||||
suffix = key[len("env_state."):]
|
||||
features[f"observation.{suffix}"] = features.pop(key)
|
||||
|
||||
return features
|
||||
```
|
||||
|
||||
**Key principles:**
|
||||
|
||||
- Use `features.pop(old_key)` to remove and get the old feature
|
||||
- Use `features[new_key] = old_feature` to add the renamed feature
|
||||
- Always return the modified features dictionary
|
||||
- Document transformations clearly in the docstring
|
||||
|
||||
### Example of usage from the codebase
|
||||
|
||||
`transform_features` is used by `RobotProcessor` to derive the dataset/policy feature contract from an initial feature set by applying each step's transformation. You can see concrete examples in the codebase:
|
||||
|
||||
- Phone teleoperation record pipeline (`examples/phone_so100_record.py`): processors like `ForwardKinematicsJointsToEE`, `GripperVelocityToJoint`, and `EEBoundsAndSafety` implement `transform_features` to declare which action/observation keys should be materialized in the dataset.
|
||||
- SO100 follower kinematics (`src/lerobot/robots/so100_follower/robot_kinematic_processor.py`): each processor's `transform_features` method adds or refines feature keys such as `observation.state.ee.{x,y,z,wx,wy,wz}` or `action.gripper.pos`.
|
||||
- Rename and tokenizer processors (`src/lerobot/processor/rename_processor.py`, `src/lerobot/processor/tokenizer_processor.py`): demonstrate key renaming and adding language token features to the contract.
|
||||
|
||||
In practice, you will often aggregate features by running `RobotProcessor.transform_features(...)` with your initial features to compute the final contract before recording or training.
|
||||
|
||||
## Helper Classes
|
||||
|
||||
LeRobot provides pre-built processor classes for common transformations. Below is a comprehensive list of registered processors in the codebase.
|
||||
|
||||
### Core processors (observations, actions, normalization)
|
||||
|
||||
- **`VanillaObservationProcessor`** (`observation_processor`): Images and state processing to LeRobot format.
|
||||
- **`NormalizerProcessor`** (`normalizer_processor`): Normalize observations/actions (mean/std or min/max to [-1, 1]).
|
||||
- **`UnnormalizerProcessor`** (`unnormalizer_processor`): Inverse of the normalizer for model outputs.
|
||||
- **`DeviceProcessor`** (`device_processor`): Move tensors to a specific device (CPU/GPU) and optional float dtype.
|
||||
- **`ToBatchProcessor`** (`to_batch_processor`): Add batch dimension to observations/actions when missing.
|
||||
- **`RenameProcessor`** (`rename_processor`): Rename observation keys using a mapping dictionary.
|
||||
- **`TokenizerProcessor`** (`tokenizer_processor`): Tokenize language tasks into `observation.language.*` tensors.
|
||||
|
||||
### Teleoperation mapping processors
|
||||
|
||||
- **`MapDeltaActionToRobotAction`** (`map_delta_action_to_robot_action`): Map teleop deltas (e.g., gamepad) to `action.target_*` fields.
|
||||
- **`MapPhoneActionToRobotAction`** (`map_phone_action_to_robot_action`): Map calibrated phone pose/buttons to `action.target_*` and gripper.
|
||||
|
||||
### Robot kinematics processors (SO100 follower example)
|
||||
|
||||
- **`EEReferenceAndDelta`** (`ee_reference_and_delta`): Compute desired EE pose from target deltas and current pose.
|
||||
- **`EEBoundsAndSafety`** (`ee_bounds_and_safety`): Clip EE pose to bounds and check for jumps.
|
||||
- **`InverseKinematicsEEToJoints`** (`inverse_kinematics_ee_to_joints`): Convert EE pose to joint targets via IK.
|
||||
- **`GripperVelocityToJoint`** (`gripper_velocity_to_joint`): Convert gripper velocity input to joint position command.
|
||||
- **`ForwardKinematicsJointsToEE`** (`forward_kinematics_joints_to_ee`): Compute EE pose features from joint positions via FK.
|
||||
- **`AddRobotObservationAsComplimentaryData`** (`add_robot_observation`): Read robot observation and insert `raw_joint_positions` into complementary data.
|
||||
|
||||
### Policy-specific utility processors
|
||||
|
||||
- **`Pi0NewLineProcessor`** (`pi0_new_line_processor`): Ensure text tasks end with a newline (Pi0 tokenizer compatibility).
|
||||
- **`SmolVLANewLineProcessor`** (`smolvla_new_line_processor`): Ensure text tasks end with a newline (SmolVLA tokenizer compatibility).
|
||||
|
||||
### Usage Example
|
||||
|
||||
```python
|
||||
from lerobot.processor import NormalizerProcessor, DeviceProcessor, RobotProcessor, ToBatchProcessor
|
||||
|
||||
# Create a processing pipeline (typical policy preprocessor)
|
||||
steps = [
|
||||
NormalizerProcessor(features=features, norm_map=norm_map, stats=stats),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device="cuda"),
|
||||
]
|
||||
|
||||
# Use in RobotProcessor
|
||||
processor = RobotProcessor(steps=steps)
|
||||
processed_transition = processor(raw_transition)
|
||||
```
|
||||
|
||||
### Using overrides
|
||||
|
||||
You can override step parameters at load-time using `overrides`. This is handy for non-serializable objects or site-specific settings. It works both in policy factories and with `RobotProcessor.from_pretrained(...)`.
|
||||
|
||||
Example: during policy evaluation on the robot, override the device and rename map.
|
||||
Use this to run a policy trained on CUDA on a CPU-only robot, or to remap camera keys when the robot uses different names than the dataset.
|
||||
|
||||
```437:445:src/lerobot/record.py
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
Direct usage with `from_pretrained`:
|
||||
|
||||
```python
|
||||
from lerobot.processor import RobotProcessor
|
||||
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
"username/my-processor",
|
||||
overrides={
|
||||
"device_processor": {"device": "cuda:0"}, # registry name for registered steps
|
||||
"CustomStep": {"param": 42}, # class name for non-registered steps
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
- **Keep processors atomic** - One transformation per processor for reusability and debugging
|
||||
- **Use dataclasses** - Clean initialization with `@dataclass`
|
||||
- **Always register processors** - Use `@ProcessorStepRegistry.register("name")` for discoverability
|
||||
- **Check for None** - Always validate required data exists before processing
|
||||
- **Use copy() for safety** - Avoid side effects with `transition.copy()`
|
||||
- **Separate config and state** - JSON-serializable config vs tensor state_dict
|
||||
- **Use base classes** - Inherit from `ObservationProcessor` for observation-only processing
|
||||
|
||||
```python
|
||||
@ProcessorStepRegistry.register("my_processor")
|
||||
@dataclass
|
||||
class MyProcessor(ObservationProcessor):
|
||||
threshold: float = 0.5
|
||||
|
||||
def observation(self, observation):
|
||||
if observation is None:
|
||||
return observation
|
||||
# Your processing logic here
|
||||
return processed_observation
|
||||
```
|
||||
|
||||
## Conclusion
|
||||
|
||||
You now have all the tools to implement custom processors in LeRobot! The key steps are:
|
||||
|
||||
1. **Define your processor** as a dataclass with the required methods (`__call__`, `get_config`, `state_dict`, `load_state_dict`, `reset`, `feature_contract`)
|
||||
2. **Register it** using `@ProcessorStepRegistry.register("name")` for discoverability
|
||||
3. **Integrate it** into a `RobotProcessor` pipeline with other processing steps
|
||||
4. **Use base classes** like `ObservationProcessor` when possible to reduce boilerplate
|
||||
|
||||
The processor system is designed to be modular and composable, allowing you to build complex data processing pipelines from simple, focused components. Whether you're preprocessing sensor data for training or post-processing model outputs for robot execution, custom processors give you the flexibility to handle any data transformation your robotics application requires. Policies like Pi0 and SmolVLA use the same normalization processors described above, so your understanding here will transfer directly when wiring policy preprocessors and postprocessors.
|
||||
|
||||
Start simple, test thoroughly, and leverage the existing helper classes to build robust data processing pipelines for your robot learning workflows.
|
||||
@@ -1,991 +0,0 @@
|
||||
# Introduction to Processors
|
||||
|
||||
In robotics, there's a fundamental mismatch between the data that robots and humans produce and what machine learning models expect. This creates several translation challenges:
|
||||
|
||||
**Raw Robot Data → Model Input:**
|
||||
|
||||
- Robots output raw sensor data (camera images, joint positions, force readings) that need normalization, batching, and device placement before models can process them
|
||||
- Language instructions from humans ("pick up the red cube") must be tokenized into numerical representations
|
||||
- Different robots use different coordinate systems and units that need standardization
|
||||
|
||||
**Model Output → Robot Commands:**
|
||||
|
||||
- Models might output end-effector positions, but robots need joint-space commands
|
||||
- Teleoperators (like gamepads) produce relative movements (delta positions), but robots expect absolute commands
|
||||
- Model predictions are often normalized and need to be converted back to real-world scales
|
||||
|
||||
**Cross-Domain Translation:**
|
||||
|
||||
- Training data from one robot setup needs adaptation for deployment on different hardware
|
||||
- Models trained with specific camera configurations must work with new camera arrangements
|
||||
- Datasets with different naming conventions need harmonization
|
||||
|
||||
**That's where processors come in.** They serve as the universal translators that bridge these gaps, ensuring seamless data flow from sensors to models to actuators.
|
||||
|
||||
Processors are the data transformation backbone of LeRobot. They handle all the preprocessing and postprocessing steps needed to convert raw environment data into model-ready inputs and vice versa. This guide will walk you through everything you need to know about processors - from basic concepts to advanced usage patterns.
|
||||
|
||||
## What are Processors?
|
||||
|
||||
In robotics, data comes in many forms - images from cameras, joint positions from sensors, text instructions from users, and more. Each type of data requires specific transformations before a model can use it effectively. Models need this data to be:
|
||||
|
||||
- **Normalized**: Scaled to appropriate ranges for neural network processing
|
||||
- **Batched**: Organized with proper dimensions for batch processing
|
||||
- **Tokenized**: Text converted to numerical representations
|
||||
- **Device-placed**: Moved to the right hardware (CPU/GPU)
|
||||
- **Type-converted**: Cast to appropriate data types
|
||||
|
||||
Processors handle these transformations through composable, reusable steps that can be chained together into pipelines. Think of them as a modular assembly line where each station performs a specific transformation on your data.
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### EnvTransition: The Universal Data Container
|
||||
|
||||
The `EnvTransition` is the fundamental data structure that flows through all processors. It's a typed dictionary that represents a complete robot-environment interaction:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import TransitionKey, EnvTransition
|
||||
|
||||
# Example transition from a robot collecting data
|
||||
transition: EnvTransition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
"observation.images.camera0": camera0_image_tensor, # Shape: (H, W, C)
|
||||
"observation.images.camera1": camera1_image_tensor, # Shape: (H, W, C)
|
||||
"observation.state": joint_positions_tensor, # Shape: (7,) for 7-DOF arm
|
||||
"observation.environment_state": env_state_tensor # Shape: (3,) for object position
|
||||
},
|
||||
TransitionKey.ACTION: action_tensor, # Shape: (7,) for joint velocities
|
||||
TransitionKey.REWARD: 0.0, # Scalar reward signal
|
||||
TransitionKey.DONE: False, # Episode termination flag
|
||||
TransitionKey.TRUNCATED: False, # Episode truncation flag
|
||||
TransitionKey.INFO: {"success": False}, # Additional metadata
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"task": "pick up the red cube", # Language instruction
|
||||
"task_index": 0, # Task identifier
|
||||
"index": 42 # Frame index
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Each key in the transition has a specific purpose:
|
||||
|
||||
- **OBSERVATION**: All sensor data (images, states, proprioception)
|
||||
- **ACTION**: The action to execute or that was executed
|
||||
- **REWARD**: Reinforcement learning signal
|
||||
- **DONE/TRUNCATED**: Episode boundary indicators
|
||||
- **INFO**: Arbitrary metadata
|
||||
- **COMPLEMENTARY_DATA**: Task descriptions, indices, padding flags, inter-step data (e.g., you need to compute the velocities and then use this velocity to clip the action)
|
||||
|
||||
### ProcessorStep: The Building Block Interface
|
||||
|
||||
A `ProcessorStep` is a single transformation unit that processes transitions. It's a protocol (interface) that any processor step must implement:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import ProcessorStep, EnvTransition
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
class MyProcessorStep:
|
||||
"""Example processor step interface - all methods must be implemented."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Transform the transition - this is the main processing logic."""
|
||||
raise NotImplementedError
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Declare how this step transforms feature shapes/types."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return JSON-serializable configuration for saving/loading."""
|
||||
raise NotImplementedError
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return any learnable parameters (tensors only)."""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load learnable parameters from saved state."""
|
||||
raise NotImplementedError
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset any internal state between episodes."""
|
||||
raise NotImplementedError
|
||||
```
|
||||
|
||||
### RobotProcessor: The Pipeline Orchestrator
|
||||
|
||||
The `RobotProcessor` chains multiple `ProcessorStep` instances together, executing them sequentially. It provides automatic format conversion to handle both batch dictionaries (from datasets) and EnvTransition dictionaries:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import RobotProcessor, _default_batch_to_transition, _default_transition_to_batch
|
||||
|
||||
# Create a processing pipeline
|
||||
processor = RobotProcessor(
|
||||
steps=[
|
||||
step1, # First transformation
|
||||
step2, # Second transformation
|
||||
step3 # Third transformation
|
||||
],
|
||||
name="my_preprocessing_pipeline",
|
||||
|
||||
# Optional: Custom converters for input/output formats
|
||||
to_transition=_default_batch_to_transition, # How to convert batch dict → EnvTransition
|
||||
to_output=_default_transition_to_batch # How to convert EnvTransition → output format
|
||||
)
|
||||
|
||||
# The processor automatically handles different input formats:
|
||||
# 1. If input is a batch dict (from dataset), converts to EnvTransition
|
||||
# 2. Passes through each step sequentially
|
||||
# 3. Converts back to original format (or custom output format)
|
||||
|
||||
# Example with batch dict input (common in training)
|
||||
batch_dict = {"observation.state": tensor, "action": tensor}
|
||||
output = processor(batch_dict) # Automatically converted to/from EnvTransition
|
||||
|
||||
# Example with EnvTransition input (common in inference)
|
||||
transition = {TransitionKey.OBSERVATION: {...}, TransitionKey.ACTION: ...}
|
||||
output = processor(transition) # Stays as EnvTransition throughout
|
||||
```
|
||||
|
||||
The `to_transition` and `to_output` converters enable seamless integration with existing codebases.
|
||||
By default, they handle the standard LeRobot batch format, but you can customize them for different data structures.
|
||||
|
||||
### Additional Converter Functions
|
||||
|
||||
LeRobot provides several specialized converter functions for common robotics scenarios:
|
||||
|
||||
```python
|
||||
from lerobot.processor.converters import (
|
||||
to_transition_teleop_action,
|
||||
to_transition_robot_observation,
|
||||
to_output_robot_action,
|
||||
to_dataset_frame
|
||||
)
|
||||
```
|
||||
|
||||
**`to_transition_teleop_action`** - Converts teleoperation device actions to EnvTransitions:
|
||||
|
||||
```python
|
||||
# Use case: Phone, gamepad, or other teleop device control
|
||||
phone_action = {"x": 0.1, "y": -0.2, "gripper": 0.8}
|
||||
transition = to_transition_teleop_action(phone_action)
|
||||
# Creates: {ACTION: {"action.x": 0.1, "action.y": -0.2, "action.gripper": 0.8}, ...}
|
||||
```
|
||||
|
||||
**`to_transition_robot_observation`** - Converts robot sensor data to EnvTransitions:
|
||||
|
||||
```python
|
||||
# Use case: Live robot observation during inference
|
||||
robot_obs = {
|
||||
"joint_1": 0.5, "joint_2": -0.3, # joint positions
|
||||
"camera_0": image_array # camera images
|
||||
}
|
||||
transition = to_transition_robot_observation(robot_obs)
|
||||
# Creates: {OBSERVATION: {"observation.state.joint_1": 0.5, "observation.images.camera_0": image, ...}}
|
||||
```
|
||||
|
||||
**`to_output_robot_action`** - Extracts robot-executable actions from EnvTransitions:
|
||||
|
||||
```python
|
||||
# Use case: Converting model outputs back to robot commands
|
||||
model_transition = {ACTION: {"action.joint_1": 0.2, "action.joint_2": 0.1}}
|
||||
robot_action = to_output_robot_action(model_transition)
|
||||
# Returns: {"joint_1": 0.2, "joint_2": 0.1} - ready for robot.send_action()
|
||||
```
|
||||
|
||||
**`to_dataset_frame`** - Converts transitions to dataset-compatible format:
|
||||
|
||||
```python
|
||||
# Use case: Saving processed data or creating training batches
|
||||
features = {
|
||||
"action": {"names": ["joint_1", "joint_2"]},
|
||||
"observation.state": {"names": ["joint_1", "joint_2"]},
|
||||
"observation.images.camera0": {...}
|
||||
}
|
||||
batch = to_dataset_frame(transition, features)
|
||||
# Returns: {"action": [0.2, 0.1], "observation.state": [0.5, -0.3], ...}
|
||||
```
|
||||
|
||||
These converters are particularly useful when integrating with real robots, as shown in the examples:
|
||||
|
||||
```python
|
||||
# Example from phone_so100_teleop.py - Real robot teleoperation
|
||||
phone_to_robot_ee_pose = RobotProcessor(
|
||||
steps=[...],
|
||||
to_transition=to_transition_teleop_action, # Phone → EnvTransition
|
||||
to_output=lambda tr: tr # Keep as EnvTransition
|
||||
)
|
||||
|
||||
# Example from phone_so100_eval.py - Robot action execution
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[...],
|
||||
to_transition=lambda tr: tr, # Already EnvTransition
|
||||
to_output=to_output_robot_action # EnvTransition → Robot action
|
||||
)
|
||||
|
||||
# Example from phone_so100_record.py - Dataset recording
|
||||
robot_joints_to_ee_pose = RobotProcessor(
|
||||
steps=[...],
|
||||
to_transition=to_transition_robot_observation, # Robot obs → EnvTransition
|
||||
to_output=lambda tr: tr # Keep as EnvTransition for dataset
|
||||
)
|
||||
```
|
||||
|
||||
### Data Format Conversion
|
||||
|
||||
Different data sources have different formats, but processors need a unified `EnvTransition` structure internally.
|
||||
The default converters handle LeRobot datasets, but you can customize them:
|
||||
|
||||
```python
|
||||
# Default: LeRobot batch format
|
||||
lerobot_batch = {
|
||||
"observation.state": torch.tensor(...),
|
||||
"action": torch.tensor(...),
|
||||
"next.reward": torch.tensor(...),
|
||||
"task": ["pick cube", ...]
|
||||
}
|
||||
# → Converts to EnvTransition → Processes → Converts back
|
||||
|
||||
# Custom: Live robot data
|
||||
robot_data = {
|
||||
"cameras": {"wrist_cam": np.array(...)},
|
||||
"joint_positions": np.array(...),
|
||||
"gripper_state": 0.5
|
||||
}
|
||||
|
||||
def robot_to_transition(data: dict) -> EnvTransition:
|
||||
return {
|
||||
TransitionKey.OBSERVATION: {
|
||||
"observation.images.wrist": torch.from_numpy(data["cameras"]["wrist_cam"]),
|
||||
"observation.state": torch.from_numpy(data["joint_positions"])
|
||||
},
|
||||
TransitionKey.ACTION: None,
|
||||
# ... other fields with defaults
|
||||
}
|
||||
|
||||
# Use custom converter
|
||||
processor = RobotProcessor(
|
||||
steps=[...],
|
||||
to_transition=robot_to_transition,
|
||||
to_output=lambda transition: transition # Keep as EnvTransition
|
||||
)
|
||||
```
|
||||
|
||||
**When to customize:** Live robot data, Gymnasium environments, legacy datasets, or any non-LeRobot format.
|
||||
|
||||
## Common Processor Steps
|
||||
|
||||
LeRobot provides a rich set of pre-built processor steps for common transformations.
|
||||
Let's explore each in detail:
|
||||
|
||||
### Data Normalization
|
||||
|
||||
Normalization is crucial for neural network training and inference.
|
||||
The `NormalizerProcessor` handles both mean-std normalization and min-max scaling:
|
||||
|
||||
```python
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode
|
||||
|
||||
# Define what features exist in your data
|
||||
features = {
|
||||
"observation.images.camera0": PolicyFeature(
|
||||
type=FeatureType.IMAGE,
|
||||
shape=(224, 224, 3)
|
||||
),
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(7,)
|
||||
),
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(7,)
|
||||
)
|
||||
}
|
||||
|
||||
# Define normalization strategy per feature type
|
||||
norm_map = {
|
||||
FeatureType.IMAGE: NormalizationMode.MEAN_STD, # Images: (x - mean) / std
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX, # States: scale to [-1, 1]
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX # Actions: scale to [-1, 1]
|
||||
}
|
||||
|
||||
# Create normalizer with dataset statistics
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats, # Contains mean, std, min, max per feature
|
||||
normalize_keys={"observation.state", "action"} # Optional: only normalize specific keys
|
||||
)
|
||||
|
||||
# For postprocessing: inverse transformation
|
||||
unnormalizer = UnnormalizerProcessor(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats
|
||||
)
|
||||
|
||||
# The normalizer automatically:
|
||||
# - Detects which normalization to apply based on feature type
|
||||
# - Handles device placement of statistics tensors
|
||||
# - Skips keys not in stats or not in normalize_keys
|
||||
# - Adds metadata about what was normalized
|
||||
```
|
||||
|
||||
### Device Management
|
||||
|
||||
The `DeviceProcessor` ensures tensors are on the right device with the right dtype:
|
||||
|
||||
```python
|
||||
from lerobot.processor.device_processor import DeviceProcessor
|
||||
|
||||
# Basic GPU placement
|
||||
gpu_processor = DeviceProcessor(device="cuda:0")
|
||||
|
||||
# Advanced: GPU with half-precision for inference
|
||||
efficient_processor = DeviceProcessor(
|
||||
device="cuda:0",
|
||||
float_dtype="float16" # Convert float32 -> float16 for memory efficiency
|
||||
)
|
||||
|
||||
# The processor:
|
||||
# - Moves all tensors to specified device
|
||||
# - Preserves non-tensor data unchanged
|
||||
# - Optionally converts float dtypes while preserving int/bool types
|
||||
# - Uses non_blocking transfers for CUDA devices
|
||||
# - Handles nested structures (observations, complementary_data)
|
||||
|
||||
# Supported float dtypes:
|
||||
# "float16" / "half": 16-bit floating point
|
||||
# "float32" / "float": 32-bit floating point (default)
|
||||
# "float64" / "double": 64-bit floating point
|
||||
# "bfloat16": Brain floating point (better for training)
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
Models expect batched inputs, but robot interactions often produce unbatched data:
|
||||
|
||||
```python
|
||||
from lerobot.processor.batch_processor import ToBatchProcessor
|
||||
|
||||
batch_processor = ToBatchProcessor()
|
||||
|
||||
# Automatically adds batch dimensions where needed:
|
||||
# State: (7,) -> (1, 7)
|
||||
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
||||
# Action: (4,) -> (1, 4)
|
||||
# Task: "pick_cube" -> ["pick_cube"]
|
||||
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
||||
|
||||
# The processor intelligently:
|
||||
# - Detects tensor dimensionality
|
||||
# - Adds batch dim to 1D states/actions
|
||||
# - Adds batch dim to 3D images
|
||||
# - Wraps string tasks in lists
|
||||
# - Preserves already-batched data
|
||||
|
||||
# Example usage in inference:
|
||||
single_observation = robot.get_observation() # Unbatched
|
||||
batched_input = batch_processor({"observation": single_observation})
|
||||
model_output = model(batched_input) # Model expects batch dim
|
||||
```
|
||||
|
||||
### Text Tokenization
|
||||
|
||||
For language-conditioned policies, text instructions must be tokenized:
|
||||
|
||||
```python
|
||||
from lerobot.processor.tokenizer_processor import TokenizerProcessor
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Option 1: Auto-load tokenizer by name
|
||||
tokenizer_proc = TokenizerProcessor(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=128,
|
||||
task_key="task", # Where to find text in complementary_data
|
||||
padding="max_length", # Pad to max_length
|
||||
padding_side="right",
|
||||
truncation=True # Truncate if longer than max_length
|
||||
)
|
||||
|
||||
# Option 2: Provide custom tokenizer
|
||||
custom_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
||||
custom_proc = TokenizerProcessor(
|
||||
tokenizer=custom_tokenizer,
|
||||
max_length=256,
|
||||
padding_side="left" # For autoregressive models
|
||||
)
|
||||
|
||||
# The processor:
|
||||
# - Extracts task text from complementary_data
|
||||
# - Tokenizes using HuggingFace tokenizer
|
||||
# - Adds tokens and attention_mask to observations
|
||||
# - Handles both single strings and lists of strings
|
||||
# - Preserves original task in complementary_data
|
||||
|
||||
# Output structure:
|
||||
# observation["observation.language.tokens"] = tensor([101, 2032, ...])
|
||||
# observation["observation.language.attention_mask"] = tensor([1, 1, 0, ...])
|
||||
```
|
||||
|
||||
### Key Renaming
|
||||
|
||||
Different datasets and models may use different naming conventions.
|
||||
The `RenameProcessor` solves this mismatch:
|
||||
|
||||
**Why is this useful?**
|
||||
|
||||
- When loading a model trained on a different dataset with different key names
|
||||
- When using foundation models that expect specific key naming conventions
|
||||
- When standardizing datasets from different sources
|
||||
- When adapting legacy code to new naming standards
|
||||
|
||||
```python
|
||||
from lerobot.processor.rename_processor import RenameProcessor
|
||||
|
||||
# Example 1: Dataset uses "top"/"wrist", model expects "camera0"/"camera1"
|
||||
rename_proc = RenameProcessor(
|
||||
rename_map={
|
||||
"observation.images.top": "observation.images.camera0",
|
||||
"observation.images.wrist": "observation.images.camera1",
|
||||
}
|
||||
)
|
||||
|
||||
# Example 2: Foundation model compatibility
|
||||
# Your dataset: "observation.state", Foundation model: "proprio"
|
||||
foundation_rename = RenameProcessor(
|
||||
rename_map={
|
||||
"observation.state": "proprio",
|
||||
"observation.images.main": "rgb",
|
||||
}
|
||||
)
|
||||
|
||||
# Example 3: Standardizing multiple datasets
|
||||
standardize_rename = RenameProcessor(
|
||||
rename_map={
|
||||
# Different robots might use different names
|
||||
"observation.joint_positions": "observation.state",
|
||||
"observation.gripper_state": "observation.end_effector",
|
||||
"observation.arm_camera": "observation.images.wrist",
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Building Complete Pipelines
|
||||
|
||||
Let's build a real-world preprocessing and postprocessing pipeline for a vision-based
|
||||
manipulation policy:
|
||||
|
||||
```python
|
||||
# Consolidated imports
|
||||
from lerobot.processor import (
|
||||
RobotProcessor,
|
||||
NormalizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
DeviceProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
RenameProcessor
|
||||
)
|
||||
|
||||
# Step 1: Define the preprocessing pipeline
|
||||
preprocessor = RobotProcessor(
|
||||
steps=[
|
||||
# 1. Standardize naming from dataset
|
||||
RenameProcessor(
|
||||
rename_map={
|
||||
"observation.images.top": "observation.images.camera0",
|
||||
"observation.images.wrist": "observation.images.camera1"
|
||||
}
|
||||
),
|
||||
|
||||
# 2. Add batch dimensions for model
|
||||
ToBatchProcessor(),
|
||||
|
||||
# 3. Tokenize language instructions if present
|
||||
TokenizerProcessor(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=64,
|
||||
task_key="task"
|
||||
),
|
||||
|
||||
# 4. Normalize numerical data
|
||||
NormalizerProcessor(
|
||||
features=policy_features,
|
||||
norm_map={
|
||||
FeatureType.IMAGE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX
|
||||
},
|
||||
stats=dataset.meta.stats
|
||||
),
|
||||
|
||||
# 5. Move to GPU and convert to half precision
|
||||
DeviceProcessor(
|
||||
device="cuda:0",
|
||||
float_dtype="float16"
|
||||
)
|
||||
],
|
||||
name="robot_preprocessor"
|
||||
)
|
||||
|
||||
# Step 2: Define the postprocessing pipeline
|
||||
postprocessor = RobotProcessor(
|
||||
steps=[
|
||||
# 1. Move back to CPU for robot hardware
|
||||
DeviceProcessor(device="cpu"),
|
||||
|
||||
# 2. Denormalize actions to original scale
|
||||
UnnormalizerProcessor(
|
||||
features=policy_features,
|
||||
norm_map={
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX
|
||||
},
|
||||
stats=dataset.meta.stats
|
||||
)
|
||||
],
|
||||
name="robot_postprocessor"
|
||||
)
|
||||
```
|
||||
|
||||
## Using Processors in Practice
|
||||
|
||||
### Training Loop Integration
|
||||
|
||||
Here's how processors integrate into a training loop using the policy's forward method:
|
||||
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Create dataset and dataloader
|
||||
dataset = LeRobotDataset(repo_id="your_dataset")
|
||||
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
|
||||
|
||||
# Initialize model and processors
|
||||
model = YourPolicy.from_pretrained("your_model")
|
||||
preprocessor = RobotProcessor.from_pretrained(
|
||||
"your_model",
|
||||
config_filename="robot_preprocessor.json"
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(num_epochs):
|
||||
for batch in dataloader:
|
||||
# Preprocess batch
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Forward pass - returns loss and optional metrics
|
||||
loss, metrics = model.forward(processed_batch)
|
||||
|
||||
# Backward pass
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Log metrics if available
|
||||
if metrics:
|
||||
wandb.log(metrics)
|
||||
```
|
||||
|
||||
### Inference Pipeline
|
||||
|
||||
For deployment, processors ensure consistent data handling with real robots:
|
||||
|
||||
```python
|
||||
# Load model and processors
|
||||
policy = YourPolicy.from_pretrained("path/to/model")
|
||||
preprocessor = RobotProcessor.from_pretrained(
|
||||
"path/to/model",
|
||||
config_filename="robot_preprocessor.json"
|
||||
)
|
||||
postprocessor = RobotProcessor.from_pretrained(
|
||||
"path/to/model",
|
||||
config_filename="robot_postprocessor.json"
|
||||
)
|
||||
|
||||
# Connect to robot
|
||||
robot = make_robot_from_config(robot_config)
|
||||
robot.connect()
|
||||
|
||||
# Inference loop
|
||||
policy.eval()
|
||||
# Reset the policy and processors
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
with torch.no_grad():
|
||||
while not done:
|
||||
# Get observation from robot
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Build dataset-compatible frame
|
||||
observation_frame = build_dataset_frame(
|
||||
dataset.features,
|
||||
observation,
|
||||
prefix="observation"
|
||||
)
|
||||
|
||||
# Add task instruction to complementary data
|
||||
observation_frame["task"] = "pick up the red cube"
|
||||
|
||||
# Preprocess for model
|
||||
model_input = preprocessor(observation_frame)
|
||||
|
||||
# Run policy
|
||||
raw_action = policy.select_action(model_input)
|
||||
|
||||
# Postprocess action
|
||||
action_transition = {TransitionKey.ACTION: raw_action}
|
||||
processed = postprocessor(action_transition)
|
||||
action = processed[TransitionKey.ACTION]
|
||||
|
||||
# Convert to robot action format
|
||||
robot_action = {
|
||||
key: action[i].item()
|
||||
for i, key in enumerate(robot.action_features)
|
||||
}
|
||||
|
||||
# Execute on robot
|
||||
robot.send_action(robot_action)
|
||||
```
|
||||
|
||||
## Saving and Loading Processors
|
||||
|
||||
Processors can be persisted and shared just like models, making them portable across different
|
||||
environments and ensuring reproducibility:
|
||||
|
||||
### Local Save/Load
|
||||
|
||||
```python
|
||||
# Save processor configuration and state
|
||||
preprocessor.save_pretrained(
|
||||
"./my_robot_processor",
|
||||
config_filename="preprocessor.json" # Optional custom name
|
||||
)
|
||||
|
||||
# The save creates:
|
||||
# my_robot_processor/
|
||||
# ├── preprocessor.json # Configuration
|
||||
# ├── preprocessor_step_0_normalizer.safetensors # Step 0 state (stats)
|
||||
# └── preprocessor_step_1_device.safetensors # Step 1 state (if any)
|
||||
|
||||
# Load processor
|
||||
loaded = RobotProcessor.from_pretrained(
|
||||
"./my_robot_processor",
|
||||
config_filename="preprocessor.json"
|
||||
)
|
||||
```
|
||||
|
||||
### HuggingFace Hub Integration
|
||||
|
||||
The HuggingFace Hub provides a centralized place to share and version your processors.
|
||||
This is particularly useful for sharing preprocessing configurations with models,
|
||||
ensuring that anyone who downloads your model can reproduce your exact preprocessing pipeline.
|
||||
It also enables versioning and collaboration on preprocessing strategies.
|
||||
|
||||
```python
|
||||
# Save to HuggingFace Hub
|
||||
preprocessor.save_pretrained("username/my-robot-policy")
|
||||
|
||||
# Load from Hub with automatic download
|
||||
hub_processor = RobotProcessor.from_pretrained(
|
||||
"username/my-robot-policy",
|
||||
config_filename="robot_preprocessor.json",
|
||||
revision="main", # Optional: specific revision
|
||||
cache_dir="./cache" # Optional: local cache directory
|
||||
)
|
||||
|
||||
# The Hub integration provides:
|
||||
# - Automatic versioning with git
|
||||
# - Public or private sharing
|
||||
# - Download caching for efficiency
|
||||
# - Integration with model repositories
|
||||
```
|
||||
|
||||
### Loading with Overrides
|
||||
|
||||
Sometimes you need to modify loaded processors for new environments or datasets.
|
||||
The override mechanism allows you to update specific processor configurations without modifying
|
||||
the saved files:
|
||||
|
||||
```python
|
||||
# Load processor with configuration overrides
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
"./saved_processor",
|
||||
overrides={
|
||||
# Change device for different hardware
|
||||
"device_processor": {"device": "cuda:1"},
|
||||
|
||||
# Update statistics for new dataset
|
||||
"normalizer_processor": {"stats": new_dataset.meta.stats},
|
||||
|
||||
# Provide non-serializable objects (like tokenizers)
|
||||
"tokenizer_processor": {"tokenizer": custom_tokenizer}
|
||||
}
|
||||
)
|
||||
|
||||
# Common override scenarios:
|
||||
# 1. Adapting to different hardware (GPU availability)
|
||||
# 2. Fine-tuning on new datasets with different statistics
|
||||
# 3. Providing runtime dependencies that can't be serialized
|
||||
# 4. Testing variations without creating new saved configs
|
||||
```
|
||||
|
||||
## Creating Custom Processor Steps
|
||||
|
||||
Build your own processor steps for specialized transformations.
|
||||
The key is implementing the required interface:
|
||||
|
||||
### Basic Custom Step with Registration
|
||||
|
||||
The registration mechanism allows your custom processors to be saved and loaded by name rather
|
||||
than by module path.
|
||||
This makes them more portable and easier to share:
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
from lerobot.processor.pipeline import ProcessorStepRegistry, ObservationProcessor
|
||||
|
||||
# The @register decorator adds your processor to the global registry
|
||||
# Use a unique name, preferably namespaced to avoid conflicts
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("my_company/gaussian_noise")
|
||||
class GaussianNoiseProcessor(ObservationProcessor):
|
||||
"""Add Gaussian noise to observations for robustness training."""
|
||||
|
||||
noise_std: float = 0.01
|
||||
training_only: bool = True
|
||||
is_training: bool = True
|
||||
|
||||
def observation(self, observation):
|
||||
"""Add noise to observation tensors."""
|
||||
if not self.is_training and self.training_only:
|
||||
return observation
|
||||
|
||||
noisy_obs = {}
|
||||
for key, value in observation.items():
|
||||
if isinstance(value, torch.Tensor) and "image" not in key:
|
||||
# Add noise to non-image observations
|
||||
noise = torch.randn_like(value) * self.noise_std
|
||||
noisy_obs[key] = value + noise
|
||||
else:
|
||||
noisy_obs[key] = value
|
||||
|
||||
return noisy_obs
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"noise_std": self.noise_std,
|
||||
"training_only": self.training_only,
|
||||
"is_training": self.is_training
|
||||
}
|
||||
|
||||
# Why register?
|
||||
# 1. Enables saving by name: config saves "my_company/gaussian_noise" instead of full module path
|
||||
# 2. More portable: Others can use your processor without your exact module structure
|
||||
# 3. Version-safe: Module refactoring won't break saved configs
|
||||
# 4. Cleaner configs: JSON shows readable names instead of long import paths
|
||||
```
|
||||
|
||||
### Using Base Classes for Common Patterns
|
||||
|
||||
LeRobot provides base classes like `ObservationProcessor`, `ActionProcessor`, etc., that handle
|
||||
the boilerplate of extracting and reinserting specific components:
|
||||
|
||||
```python
|
||||
from lerobot.processor import ActionProcessor
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("my_company/action_clipper")
|
||||
class ActionClipProcessor(ActionProcessor):
|
||||
"""Clip actions to safe ranges."""
|
||||
|
||||
min_value: float = -1.0
|
||||
max_value: float = 1.0
|
||||
|
||||
def action(self, action):
|
||||
"""Process only the action component."""
|
||||
# No need to handle transition dict - base class does it
|
||||
return torch.clamp(action, self.min_value, self.max_value)
|
||||
|
||||
def get_config(self):
|
||||
return {"min_value": self.min_value, "max_value": self.max_value}
|
||||
```
|
||||
|
||||
For more advanced processor patterns including stateful processors, see [Implement Your Own Processor](implement_your_own_processor.mdx).
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Debugging with Hooks
|
||||
|
||||
Processors support hooks for monitoring and debugging without modifying the pipeline code:
|
||||
|
||||
```python
|
||||
# Define monitoring hooks
|
||||
def log_shapes(step_idx: int, transition: EnvTransition):
|
||||
"""Log tensor shapes after each step."""
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
if obs:
|
||||
print(f"Step {step_idx} shapes:")
|
||||
for key, value in obs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
print(f" {key}: {value.shape}")
|
||||
|
||||
def check_nans(step_idx: int, transition: EnvTransition):
|
||||
"""Check for NaN values."""
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
if obs:
|
||||
for key, value in obs.items():
|
||||
if isinstance(value, torch.Tensor) and torch.isnan(value).any():
|
||||
print(f"Warning: NaN detected in {key} at step {step_idx}")
|
||||
|
||||
# Register hooks
|
||||
processor.register_after_step_hook(log_shapes)
|
||||
processor.register_after_step_hook(check_nans)
|
||||
|
||||
# Process data - hooks will be called after each step
|
||||
output = processor(input_data)
|
||||
|
||||
# Remove hooks when done debugging
|
||||
processor.unregister_after_step_hook(log_shapes)
|
||||
processor.unregister_after_step_hook(check_nans)
|
||||
```
|
||||
|
||||
### Step-by-Step Inspection
|
||||
|
||||
Use `step_through()` for detailed debugging of the transformation pipeline:
|
||||
|
||||
```python
|
||||
# Inspect data at each transformation stage
|
||||
for i, intermediate in enumerate(processor.step_through(data)):
|
||||
print(f"\n=== After step {i} ===")
|
||||
|
||||
# Check observation shapes
|
||||
obs = intermediate.get(TransitionKey.OBSERVATION)
|
||||
if obs:
|
||||
for key, value in obs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
print(f"{key}: shape={value.shape}, "
|
||||
f"dtype={value.dtype}, "
|
||||
f"device={value.device}, "
|
||||
f"range=[{value.min():.3f}, {value.max():.3f}]")
|
||||
|
||||
# Check action if present
|
||||
action = intermediate.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
print(f"action: shape={action.shape}, range=[{action.min():.3f}, {action.max():.3f}]")
|
||||
```
|
||||
|
||||
### Pipeline Slicing
|
||||
|
||||
Extract subsets of a pipeline for testing or creating variations:
|
||||
|
||||
```python
|
||||
# Get specific steps
|
||||
first_three_steps = processor[:3] # Returns new RobotProcessor
|
||||
middle_step = processor[2] # Returns single ProcessorStep
|
||||
|
||||
# Test individual steps
|
||||
test_input = {...}
|
||||
step_output = processor[0](test_input) # Test first step only
|
||||
|
||||
# Create variations
|
||||
variant_processor = RobotProcessor(
|
||||
steps=processor.steps[:-1] + [new_final_step],
|
||||
name="variant"
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices and Tips
|
||||
|
||||
### 1. Order Matters
|
||||
|
||||
The sequence of processors is crucial. Follow this general order:
|
||||
|
||||
```python
|
||||
# Preprocessing: Raw → Model-ready
|
||||
1. Rename (standardize keys)
|
||||
2. Batch (add dimensions)
|
||||
3. Tokenize (text → tokens)
|
||||
4. Normalize (scale values)
|
||||
5. Device (move to GPU)
|
||||
|
||||
# Postprocessing: Model → Robot-ready
|
||||
1. Device (move to CPU)
|
||||
2. Unnormalize (restore scale)
|
||||
3. Unbatch (remove dimensions if needed)
|
||||
```
|
||||
|
||||
### 2. Registration Best Practices
|
||||
|
||||
```python
|
||||
# Always register custom steps for better portability
|
||||
@ProcessorStepRegistry.register("my_company/special_processor")
|
||||
class SpecialProcessor:
|
||||
...
|
||||
|
||||
# Use namespaced names to avoid conflicts
|
||||
# Good: "my_company/augmentation"
|
||||
# Bad: "augmentation" (too generic)
|
||||
|
||||
# Check registered processors
|
||||
print(ProcessorStepRegistry.list()) # See all registered processors
|
||||
```
|
||||
|
||||
### 3. Common Pitfalls and Solutions
|
||||
|
||||
**Tensor Device Mismatch:**
|
||||
|
||||
```python
|
||||
# Problem: RuntimeError: Expected all tensors on same device
|
||||
# Solution: Ensure DeviceProcessor is in pipeline
|
||||
preprocessor = RobotProcessor(
|
||||
steps=[
|
||||
NormalizerProcessor(...),
|
||||
DeviceProcessor(device="cuda") # Add this
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
**Missing Statistics:**
|
||||
|
||||
```python
|
||||
# Problem: NormalizerProcessor has no stats
|
||||
# Solution 1: Compute stats from dataset
|
||||
from lerobot.datasets.compute_stats import compute_stats
|
||||
stats = compute_stats(dataset)
|
||||
|
||||
# Solution 2: Load with overrides
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
"model_path",
|
||||
overrides={"normalizer_processor": {"stats": dataset.meta.stats}}
|
||||
)
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
Now that you understand processors, explore these topics:
|
||||
|
||||
- [**Implement Your Own Processor**](implement_your_own_processor.mdx) - Deep dive into creating custom processors with advanced features like stateful processing
|
||||
- [**Policy Documentation**](policies.mdx) - Learn how different policies use processors
|
||||
- [**Dataset Documentation**](datasets.mdx) - Understand the data format that processors transform
|
||||
- [**Training Guide**](training.mdx) - See processors in action during model training
|
||||
- [**Evaluation Guide**](evaluation.mdx) - Learn about processor usage during policy evaluation
|
||||
|
||||
## Summary
|
||||
|
||||
Processors are the unsung heroes of robotics pipelines, handling the critical transformations between raw sensor data and model-ready tensors. By understanding and effectively using processors, you can:
|
||||
|
||||
- Build robust, reusable data pipelines
|
||||
- Share preprocessing configurations across projects
|
||||
- Debug data transformations systematically
|
||||
- Ensure consistency between training and deployment
|
||||
- Create custom transformations for specialized tasks
|
||||
|
||||
Remember: good preprocessing is often the difference between a model that works in theory
|
||||
and one that works in practice!
|
||||
The modular pipeline approach ensures your transformations are testable, reproducible,
|
||||
and portable across different robots and environments.
|
||||
@@ -31,7 +31,7 @@ pip install -e ".[dynamixel]"
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -98,7 +98,7 @@ For a visual reference on how to set the motor ids please refer to [this video](
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -174,7 +174,7 @@ Do the same steps for the leader arm but modify the command or script accordingl
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -211,7 +211,7 @@ Run the following command or API example to calibrate the follower arm:
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
|
||||
@@ -249,7 +249,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -60,7 +60,7 @@ First, we will assemble the two SO100/SO101 arms. One to attach to the mobile ba
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -116,7 +116,7 @@ The instructions for configuring the motors can be found in the SO101 [docs](./s
|
||||
You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7)
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=lekiwi \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -174,7 +174,7 @@ The calibration process is very important because it allows a neural network tra
|
||||
Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm:
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=lekiwi \
|
||||
--robot.id=my_awesome_kiwi # <- Give the robot a unique name
|
||||
```
|
||||
@@ -193,7 +193,7 @@ Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -0,0 +1,230 @@
|
||||
# LIBERO
|
||||
|
||||
**LIBERO** is a benchmark designed to study **lifelong robot learning**. The idea is that robots won’t just be pretrained once in a factory, they’ll need to keep learning and adapting with their human users over time. This ongoing adaptation is called **lifelong learning in decision making (LLDM)**, and it’s a key step toward building robots that become truly personalized helpers. The benchmark was first introduced in the [LIBERO paper](https://arxiv.org/abs/2306.03310) and the [original repository](https://github.com/Lifelong-Robot-Learning/LIBERO).
|
||||
|
||||
To make progress on this challenge, LIBERO provides a set of standardized tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each other’s work.
|
||||
|
||||
LIBERO includes **five task suites**:
|
||||
|
||||
- **LIBERO-Spatial (`libero_spatial`)** – tasks that require reasoning about spatial relations.
|
||||
- **LIBERO-Object (`libero_object`)** – tasks centered on manipulating different objects.
|
||||
- **LIBERO-Goal (`libero_goal`)** – goal-conditioned tasks where the robot must adapt to changing targets.
|
||||
- **LIBERO-90 (`libero_90`)** – 90 short-horizon tasks from the LIBERO-100 collection.
|
||||
- **LIBERO-Long (`libero_10`)** – 10 long-horizon tasks from the LIBERO-100 collection.
|
||||
|
||||
Together, these suites cover **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios. LIBERO is meant to grow over time, and to serve as a shared benchmark where the community can test and improve lifelong learning algorithms.
|
||||
|
||||

|
||||
_Figure 1: An overview of the LIBERO benchmark._
|
||||
|
||||
## Evaluating with LIBERO
|
||||
|
||||
At **LeRobot**, we ported [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) into our framework and used it primarily to **benchmark [SmolVLA](https://huggingface.co/docs/lerobot/en/smolvla)**, our lightweight Vision-Language-Action model, comparing it against state-of-the-art VLA models such as Pi0, OpenVLA, Octo, and Diffusion Policy.
|
||||
|
||||
LIBERO is now part of our **multi-eval supported simulation**, allowing you to benchmark your policies either on a **single suite of tasks** or across **multiple suites at once** with just a single flag.
|
||||
|
||||
To install LIBERO, first follow the [LeRobot Installation Guide](https://huggingface.co/docs/lerobot/installation).
|
||||
Once LeRobot is installed, there are two options:
|
||||
|
||||
1. **Install via pip** (recommended):
|
||||
|
||||
```bash
|
||||
pip install "lerobot[libero,smolvla]"
|
||||
```
|
||||
|
||||
2. **Install from source**:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e ".[libero,smolvla]"
|
||||
```
|
||||
|
||||
### Single-suite evaluation
|
||||
|
||||
Evaluate a policy on one LIBERO suite:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/eval.py \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--env.multitask_eval=False \
|
||||
--eval.batch_size=2 \
|
||||
--eval.n_episodes=3
|
||||
```
|
||||
|
||||
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
|
||||
- `--eval.batch_size` controls how many environments run in parallel.
|
||||
- `--eval.n_episodes` sets how many episodes to run in total.
|
||||
|
||||
---
|
||||
|
||||
### Multi-suite evaluation
|
||||
|
||||
Benchmark a policy across multiple suites at once:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/eval.py \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--env.multitask_eval=True \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=2
|
||||
```
|
||||
|
||||
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
|
||||
- Set `-env.multitask_eval=True` to enable evaluation across all tasks in those suites.
|
||||
|
||||
### Policy inputs and outputs
|
||||
|
||||
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
|
||||
|
||||
- **Observations**
|
||||
- `observation.state` – proprioceptive features (agent state).
|
||||
- `observation.images.image` – main camera view (`agentview_image`).
|
||||
- `observation.images.image2` – wrist camera view (`robot0_eye_in_hand_image`).
|
||||
|
||||
⚠️ **Note:** LeRobot enforces the `.images.*` prefix for any visual features. Make sure your dataset metadata keys match this convention when evaluating.
|
||||
|
||||
## Input Features and Metadata Alignment
|
||||
|
||||
To train or evaluate a policy, you use `make_policy`, which builds a feature-naming dictionary for the observations the policy expects.
|
||||
This mapping can come from:
|
||||
- Dataset metadata
|
||||
- The evaluation environment
|
||||
- The policy path (if a pretrained repo ID is provided)
|
||||
|
||||
### Common Issues
|
||||
|
||||
A common problem is when the keys in the dataset, environment, and policy config do not match. For example:
|
||||
- `wrist_image` vs `observation.images.image2`
|
||||
- `observation.image2` (as in SmolVLA) vs the `.images.*` prefix convention
|
||||
|
||||
Such mismatches will cause `KeyError`s. This may be due to assumptions in `make_policy` or missing error handling.
|
||||
|
||||
***
|
||||
|
||||
### How to Check Expected Features
|
||||
- Open your policy config (`config.json`), e.g. [example here](https://huggingface.co/jadechoghari/smolvla-libero/blob/main/config.json).
|
||||
- Or add a breakpoint in `train.py` and inspect:
|
||||
|
||||
````python
|
||||
print(policy.config.input_features)
|
||||
To ensure you can just check what your policy expects as `input_features`:
|
||||
|
||||
- Open your policy config (`config.json`), e.g. [example here](https://huggingface.co/jadechoghari/smolvla-libero/blob/main/config.json).
|
||||
- Or add a breakpoint in `train.py` and inspect:
|
||||
```python
|
||||
print(policy.config.input_features)
|
||||
Fixing KeyErrors (Preprocessing Example)
|
||||
````
|
||||
|
||||
## Fixing KeyErrors (Preprocessing Example)
|
||||
|
||||
If your dataset columns do not follow the expected naming, you can rename them in-place before training:
|
||||
|
||||
````python
|
||||
import pyarrow.parquet as pq
|
||||
import shutil
|
||||
|
||||
def rename_columns(parquet_path, rename_map):
|
||||
table = pq.read_table(parquet_path)
|
||||
schema = table.schema
|
||||
new_names = [rename_map.get(name, name) for name in schema.names]
|
||||
renamed_table = table.rename_columns(new_names)
|
||||
backup_path = parquet_path + ".bak"
|
||||
shutil.copy(parquet_path, backup_path)
|
||||
pq.write_table(renamed_table, parquet_path)
|
||||
print(f"patched {parquet_path}, backup at {backup_path}")
|
||||
|
||||
# example mapping: align dataset keys to LeRobot convention
|
||||
rename_map = {
|
||||
"image": "observation.images.image",
|
||||
"wrist_image": "observation.images.image2",
|
||||
}
|
||||
|
||||
rename_columns("episode_000001.parquet", rename_map)
|
||||
|
||||
|
||||
|
||||
- **Actions**
|
||||
- Continuous control values in a `Box(-1, 1, shape=(7,))` space.
|
||||
|
||||
We also provide a notebook for quick testing:
|
||||
Training with LIBERO
|
||||
|
||||
## Training with LIBERO
|
||||
|
||||
When training on LIBERO tasks, make sure your dataset parquet and metadata keys follow the LeRobot convention.
|
||||
|
||||
The environment expects:
|
||||
|
||||
- `observation.state` → 8-dim agent state
|
||||
- `observation.images.image` → main camera (`agentview_image`)
|
||||
- `observation.images.image2` → wrist camera (`robot0_eye_in_hand_image`)
|
||||
|
||||
⚠️ Cleaning the dataset upfront is **cleaner and more efficient** than remapping keys inside the code. We plan to provide a script to easily preprocess such data.
|
||||
To avoid potential mismatches and `KeyError`s, we provide a **preprocessed LIBERO dataset** that is fully compatible with the current LeRobot codebase and requires no additional manipulations.
|
||||
|
||||
- 🔗 [Preprocessed LIBERO dataset (Hugging Face LeRobot org)](https://huggingface.co/datasets/HuggingFaceVLA/libero)
|
||||
- 🔗 [Original LIBERO dataset (physical-intelligence)](https://huggingface.co/datasets/physical-intelligence/libero)
|
||||
|
||||
The preprocessed dataset follows LeRobot naming conventions (e.g., `.images.*` prefix for visual features) and aligns with policy configs out-of-the-box.
|
||||
The original dataset is acknowledged here as the primary source.
|
||||
---
|
||||
|
||||
### Example training command
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/train.py \
|
||||
--policy.type=smolvla \
|
||||
--policy.repo_id=${HF_USER}/libero-test \
|
||||
--dataset.repo_id=jadechoghari/smol-libero3 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--output_dir=./outputs/ \
|
||||
--steps=100000 \
|
||||
--batch_size=4 \
|
||||
--env.multitask_eval=True \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval_freq=1000 \
|
||||
````
|
||||
|
||||
---
|
||||
|
||||
### Note on rendering
|
||||
|
||||
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
|
||||
|
||||
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
|
||||
|
||||
---
|
||||
|
||||
## Colab Note on Parallel Evaluation
|
||||
|
||||
When running evaluation on Colab, you may encounter warnings such as:
|
||||
|
||||
```
|
||||
UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
|
||||
```
|
||||
|
||||
This happens because Colab’s rendering contexts are **not thread-safe**, and `ThreadPoolExecutor(max_workers=num_workers)` can trigger segfaults or leaked semaphore warnings.
|
||||
|
||||
**Colab Note:**
|
||||
Parallel evaluation is not supported in Colab. To avoid these issues, run sequentially or disable multitask evaluation:
|
||||
|
||||
Run sequentially:
|
||||
|
||||
```bash
|
||||
--env.max_parallel_tasks=1
|
||||
```
|
||||
|
||||
Or disable multitask evaluation:
|
||||
|
||||
```bash
|
||||
--env.multitask_eval=False
|
||||
```
|
||||
|
||||
If you want to take advantage of **parallel evaluation**, we recommend **not using Colab**. Instead, run locally or on a proper compute environment where multi-threaded rendering is easily supported.
|
||||
@@ -1,195 +0,0 @@
|
||||
# Phone
|
||||
|
||||
Use your phone (iOS or Android) to control your robot.
|
||||
|
||||
**In this guide you'll learn:**
|
||||
|
||||
- How to connect an iOS/Android phone
|
||||
- How phone pose is mapped to robot end‑effector (EE) targets
|
||||
- How to tweak safety limits, gripper control, and IK settings
|
||||
|
||||
To use phone to control your robot, install the relevant dependencies with:
|
||||
|
||||
```bash
|
||||
pip install lerobot[phone]
|
||||
```
|
||||
|
||||
## Get started
|
||||
|
||||
### Supported platforms
|
||||
|
||||
- iOS: Uses the HEBI Mobile I/O app (ARKit pose + buttons). Download the app first, open it and the examples will discover it on your network and stream the phone pose and inputs.
|
||||
- Android: Uses the `teleop` package (WebXR). When you start the Python process, it prints a local URL. Open the link on your phone, tap Start, then use Move to stream pose.
|
||||
|
||||
Links:
|
||||
|
||||
- Android WebXR library: [`teleop` on PyPI](https://pypi.org/project/teleop/)
|
||||
- iOS app: [HEBI Mobile I/O](https://docs.hebi.us/tools.html#mobile-io)
|
||||
|
||||
### Phone orientation and controls
|
||||
|
||||
- Orientation: hold the phone with the screen facing up and the top edge pointing in the same direction as the robot gripper. This ensures calibration aligns the phone’s frame with the robot frame so motion feels natural.
|
||||
- Enable/disable:
|
||||
- iOS: Hold `B1` to enable teleoperation, release to stop. The first press captures a reference pose.
|
||||
- Android: Press and hold the `Move` button, release to stop. The first press captures a reference pose.
|
||||
- Gripper control:
|
||||
- iOS: Analog input `A3` controls the gripper as velocity input.
|
||||
- Android: Buttons `A` and `B` act like increment/decrement (A opens, B closes). You can tune velocity in the `GripperVelocityToJoint` step.
|
||||
|
||||
### Step 1: Choose the platform
|
||||
|
||||
Modify the examples to use `PhoneOS.IOS` or `PhoneOS.ANDROID` in `PhoneConfig`. The API is identical across platforms, only the input source differs. All examples are under `examples/` and have `phone_so100_*.py` variants.
|
||||
|
||||
Teleoperation example:
|
||||
|
||||
```36:43:examples/phone_so100_teleop.py
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
teleop_device = Phone(teleop_config)
|
||||
```
|
||||
|
||||
### Step 2: Connect and calibrate
|
||||
|
||||
When `Phone(teleop_config)` is created and `connect()` is called, calibration is prompted automatically. Hold the phone in the orientation described above, then:
|
||||
|
||||
- iOS: press and hold `B1` to capture the reference pose.
|
||||
- Android: press `Move` button on the WebXR page to capture the reference pose.
|
||||
|
||||
Why calibrate? We capture the current pose so subsequent poses are expressed in a robot aligned frame. When you again press the button to enable control, the position is recaptured to avoid drift when your phone is repositioned while it was disabled.
|
||||
|
||||
### Step 3: Run an example
|
||||
|
||||
Run on of the examples scripts to teleoperate, record a dataset, replay a dataset or evaluate a policy.
|
||||
|
||||
All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port.
|
||||
|
||||
- Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move.
|
||||
- iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper.
|
||||
|
||||
You can customize mapping or safety limits by editing the processor steps shown in the examples.
|
||||
|
||||
You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide.
|
||||
|
||||
- Run this example to teleoperate:
|
||||
|
||||
```bash
|
||||
python examples/phone_so100_teleop.py
|
||||
```
|
||||
|
||||
- Run this example to record a dataset, which saves absolute end effector observations and actions:
|
||||
|
||||
```bash
|
||||
python examples/phone_so100_record.py
|
||||
```
|
||||
|
||||
- Run this example to replay recorded episodes:
|
||||
|
||||
```bash
|
||||
python examples/phone_so100_replay.py
|
||||
```
|
||||
|
||||
- Run this example to evaluate a pretrained policy:
|
||||
|
||||
```bash
|
||||
python examples/phone_so100_eval.py
|
||||
```
|
||||
|
||||
### Important pipeline steps and options
|
||||
|
||||
- Kinematics are used in multiple steps. We use [Placo](https://github.com/Rhoban/placo) which is a wrapper around Pinocchio for handling our kinematics. We construct the kinematics object by passing the robot's URDF and target frame. We set `target_frame_name` to the gripper frame.
|
||||
|
||||
```44:49:examples/phone_so100_teleop.py
|
||||
RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
```
|
||||
|
||||
- The `MapPhoneActionToRobotAction` step converts the calibrated phone pose and inputs into target deltas and gripper commands, below is shown what the step outputs.
|
||||
|
||||
```72:83:src/lerobot/teleoperators/phone/phone_processor.py
|
||||
# Map calibrated phone pose to robot targets (enabled gates the motion)
|
||||
act.update(
|
||||
{
|
||||
"action.enabled": enabled,
|
||||
"action.target_x": -pos[1] if enabled else 0.0,
|
||||
"action.target_y": pos[0] if enabled else 0.0,
|
||||
"action.target_z": pos[2] if enabled else 0.0,
|
||||
"action.target_wx": rotvec[1] if enabled else 0.0,
|
||||
"action.target_wy": rotvec[0] if enabled else 0.0,
|
||||
"action.target_wz": -rotvec[2] if enabled else 0.0,
|
||||
"action.gripper": gripper,
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
- The `EEReferenceAndDelta` step converts target deltas to an absolute desired EE pose, storing a reference on enable, the `end_effector_step_sizes` are the step sizes for the EE pose and can be modified to change the motion speed.
|
||||
|
||||
```56:65:examples/phone_so100_teleop.py
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
```
|
||||
|
||||
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` and `max_ee_twist_step_rad` are the step limits for the EE pose and can be modified to change the safety limits.
|
||||
|
||||
```61:66:examples/phone_so100_teleop.py
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
)
|
||||
```
|
||||
|
||||
- The `GripperVelocityToJoint` step turns a velocity‑like gripper input into absolute gripper position using the current measured state. The `speed_factor` is the factor by which the velocity is multiplied.
|
||||
|
||||
```78:81:examples/phone_so100_teleop.py
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
)
|
||||
```
|
||||
|
||||
#### Different IK initial guesses
|
||||
|
||||
We use different IK initial guesses in the kinematic steps. As initial guess either the current measured joints or the previous IK solution is used.
|
||||
|
||||
- Closed loop (used in record/eval): sets `initial_guess_current_joints=True` so IK starts from the measured joints each frame.
|
||||
|
||||
```71:76:examples/phone_so100_eval.py
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True, # closed loop
|
||||
)
|
||||
```
|
||||
|
||||
- Open loop (used in replay): sets `initial_guess_current_joints=False` so IK continues from the previous IK solution rather than the measured state. This preserves action stability when we replay without feedback.
|
||||
|
||||
```80:86:examples/phone_so100_replay.py
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # open loop
|
||||
)
|
||||
```
|
||||
|
||||
### Pipeline steps explained
|
||||
|
||||
- MapPhoneActionToRobotAction: converts calibrated phone pose and inputs into target deltas and a gripper command. Motion is gated by an enable signal (B1 on iOS, Move on Android).
|
||||
- AddRobotObservationAsComplimentaryData: reads current robot joints and inserts them under `complementary_data.raw_joint_positions` for FK/IK steps to use.
|
||||
- EEReferenceAndDelta: latches a reference EE pose on enable and combines it with target deltas to produce an absolute desired EE pose each frame. When disabled, it keeps sending the last commanded pose.
|
||||
- EEBoundsAndSafety: clamps the EE pose to a workspace and rate‑limits jumps for safety. Also declares `action.ee.*` features.
|
||||
- InverseKinematicsEEToJoints: turns an EE pose into joint positions with IK. `initial_guess_current_joints=True` is recommended for closed‑loop control; set `False` for open‑loop replay for stability.
|
||||
- GripperVelocityToJoint: integrates a velocity‑like gripper input into an absolute gripper position using the current measured state.
|
||||
- ForwardKinematicsJointsToEE: computes `observation.state.ee.*` from observed joints for logging and training on EE state.
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
- iOS not discovered: ensure HEBI Mobile I/O is open and your laptop/phone are on the same network.
|
||||
- Android URL not reachable: check local you used `https` instead of `http`, use the exact IP printed by the script and allow your browser to enter and ignore the certificate issue.
|
||||
- Motion feels inverted: adjust the sign flips in `MapPhoneActionToRobotAction` or swap axes to match your setup.
|
||||
@@ -1,148 +0,0 @@
|
||||
# Processors for Robots and Teleoperators
|
||||
|
||||
This guide shows how to build and modify processing pipelines that connect teleoperators (e.g., phone) to robots and datasets. Pipelines standardize conversions between different action/observation spaces so you can swap teleops and robots without rewriting glue code.
|
||||
|
||||
We use the Phone to SO‑100 follower examples for concreteness, but the same patterns apply to other robots.
|
||||
|
||||
**What you'll learn**
|
||||
|
||||
- Absolute vs. relative EE control: What each means, trade‑offs, and how to choose for your task.
|
||||
- Three-pipeline pattern: How to map teleop actions → dataset actions → robot commands, and robot observations → dataset observations.
|
||||
- Adapters (`to_transition` / `to_output`): How these convert raw dicts to `EnvTransition` and back to reduce boilerplate.
|
||||
- Dataset feature contracts: How steps declare features via `transform_features(...)`, and how to aggregate/merge them for recording.
|
||||
- Choosing a representation: When to store joints, absolute EE poses, or relative EE deltas—and how that affects training.
|
||||
- Pipeline customization guidance: How to swap robots/URDFs safely and tune bounds, step sizes, and options like IK initialization.
|
||||
|
||||
### Absolute vs relative EE control
|
||||
|
||||
The examples in this guide use absolute end effector (EE) poses because they are easy to reason about. In practice, relative EE deltas or joint position are often preferred as learning features.
|
||||
|
||||
You can choose what you save and learn from the teleop and robot action spaces, joints, absolute EE, or relative EE by using/implementing the right steps (and `transform_features()`) in your pipelines.
|
||||
|
||||
## Three pipelines
|
||||
|
||||
We often compose three pipelines. Depending on your setup, some can be empty if action and observation spaces already match.
|
||||
Each of these pipelines handle different conversions between different action and observation spaces. Below is a quick explanation of each pipeline.
|
||||
|
||||
1. Pipeline 1: Teleop action space → dataset action space (phone pose → EE targets)
|
||||
2. Pipeline 2: Dataset action space → robot command space (EE targets → joints)
|
||||
3. Pipeline 3: Robot observation space → dataset observation space (joints → EE pose)
|
||||
|
||||
Below is an example of the three pipelines that we use in the phone to SO-100 follower examples:
|
||||
|
||||
```69:90:examples/phone_so100_record.py
|
||||
phone_to_robot_ee_pose = RobotProcessor( # teleop -> dataset action
|
||||
steps=[MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
EEReferenceAndDelta(kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys())),
|
||||
EEBoundsAndSafety(end_effector_bounds={"min": [-1, -1, -1], "max": [1, 1, 1]},
|
||||
max_ee_step_m=0.20, max_ee_twist_step_rad=0.50)],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
robot_ee_to_joints = RobotProcessor( # dataset action -> robot
|
||||
steps=[InverseKinematicsEEToJoints(kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True),
|
||||
GripperVelocityToJoint(motor_names=list(robot.bus.motors.keys()), speed_factor=20.0)],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
robot_joints_to_ee_pose = RobotProcessor( # robot obs -> dataset obs
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()))],
|
||||
to_transition=to_transition_robot_observation,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
```
|
||||
|
||||
## Why to_transition / to_output
|
||||
|
||||
To convert from robot/teleoperator to pipeline and back, we use the `to_transition` and `to_output` pipeline adapters.
|
||||
They standardize conversions to reduce boilerplate code, and form the bridge between the robot and teleoperators raw dicts and the pipeline’s `EnvTransition` format.
|
||||
In the phone to SO-100 follower examples we use the following adapters:
|
||||
|
||||
- `to_transition_teleop_action`: transforms the teleop action dict to a pipeline transition (puts keys under `action.*`, converts scalars/arrays to tensors, keeps objects like `Rotation` intact)
|
||||
- `to_output_robot_action`: transforms the pipeline transition to a robot action dict (extracts keys ending with `.pos`/`.vel` and strips `action.` prefix)
|
||||
- `to_transition_robot_observation`: transforms the robot observation dict to a pipeline transition (splits state vs images; stores state under `observation.state.*` and images under `observation.images.*`)
|
||||
|
||||
See `src/lerobot/processor/converters.py` for more details.
|
||||
|
||||
## Dataset feature contracts
|
||||
|
||||
Dataset features are the keys saved in the dataset. Each step can declare what its dataset features are via `transform_features(...)`. We can then aggregate features per pipeline with `aggregate_pipeline_dataset_features()` and merge multiple groups with `merge_features(...)`.
|
||||
|
||||
Below is and example of how we declare features with the `transform_features` method in the phone to SO-100 follower examples:
|
||||
|
||||
```203:211:src/lerobot/robots/so100_follower/robot_kinematic_processor.py
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# Because this is last step we specify the dataset features of this step that we want to be stored in the dataset
|
||||
features["action.ee.x"] = float
|
||||
features["action.ee.y"] = float
|
||||
features["action.ee.z"] = float
|
||||
features["action.ee.wx"] = float
|
||||
features["action.ee.wy"] = float
|
||||
features["action.ee.wz"] = float
|
||||
return features
|
||||
```
|
||||
|
||||
Tip: declare features at the last step that produces them (e.g., `EEBoundsAndSafety` declares `action.ee.*`, `ForwardKinematicsJointsToEE` declares `observation.state.ee.*`).
|
||||
|
||||
Below is an example of how we aggregate and merge features in the phone to SO-100 follower examples:
|
||||
|
||||
```121:145:examples/phone_so100_record.py
|
||||
action_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose,
|
||||
initial_features=phone.action_features,
|
||||
use_videos=True,
|
||||
patterns=["action.ee"],
|
||||
)
|
||||
|
||||
gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints,
|
||||
initial_features={},
|
||||
use_videos=True,
|
||||
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
|
||||
)
|
||||
|
||||
observation_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=robot.observation_features,
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
)
|
||||
|
||||
dataset_features = merge_features(action_ee, gripper, observation_ee)
|
||||
```
|
||||
|
||||
How it works:
|
||||
|
||||
- `aggregate_pipeline_dataset_features(...)`: applies `transform_features` across the pipeline and filters by patterns (images included when `use_videos=True`).
|
||||
- `merge_features(...)`: combine multiple feature dicts.
|
||||
- Recording uses `to_dataset_frame(...)` to build frames consistent with `dataset.features` before we call `add_frame(...)` to add the frame to the dataset.
|
||||
|
||||
## Guidance when customizing robot pipelines
|
||||
|
||||
You can store any of the following features as your action/observation space:
|
||||
|
||||
- Joint positions
|
||||
- Absolute EE poses
|
||||
- Relative EE deltas
|
||||
- Other features: joint velocity, etc.
|
||||
|
||||
Pick what you want to use for your policy action and observation space and configure/modify the pipelines and steps accordingly.
|
||||
|
||||
### Different robots
|
||||
|
||||
- Swap `RobotKinematics` URDF and `motor_names`. Ensure `target_frame_name` points to your gripper/wrist.
|
||||
|
||||
### Safety first
|
||||
|
||||
- When changing pipelines, start with tight bounds, implement safety steps when working with real robots.
|
||||
- Its advised to start with simulation first and then move to real robots.
|
||||
|
||||
Hope this guide helps you get started with customizing your robot pipelines, If you run into any issues at any point, jump into our [Discord community](https://discord.com/invite/s3KuuzsPFb) for support.
|
||||
@@ -0,0 +1,288 @@
|
||||
# Reachy 2
|
||||
|
||||
Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
|
||||
Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!
|
||||
|
||||
## Teleoperate Reachy 2
|
||||
|
||||
Currently, there are two ways to teleoperate Reachy 2:
|
||||
|
||||
- Pollen Robotics’ VR teleoperation (not included in LeRobot).
|
||||
- Robot-to-robot teleoperation (use one Reachy 2 to control another).
|
||||
|
||||
## Reachy 2 Simulation
|
||||
|
||||
**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).
|
||||
|
||||
1. Install [Docker Engine](https://docs.docker.com/engine/).
|
||||
2. Run (for MuJoCo):
|
||||
|
||||
```
|
||||
docker run --rm -it \
|
||||
--name reachy \
|
||||
--privileged \
|
||||
--network host \
|
||||
--ipc host \
|
||||
--device-cgroup-rule='c 189:* rwm' \
|
||||
--group-add audio \
|
||||
-e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
|
||||
-e DISPLAY="$DISPLAY" \
|
||||
-e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
|
||||
-e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
|
||||
-v /dev:/dev \
|
||||
-v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
|
||||
-v "$HOME/.reachy.log":/home/reachy/.ros/log \
|
||||
-v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
|
||||
--entrypoint /package/launch.sh \
|
||||
pollenrobotics/reachy2_core:1.7.5.9_deploy \
|
||||
start_rviz:=true start_sdk_server:=true mujoco:=true
|
||||
```
|
||||
|
||||
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
|
||||
>
|
||||
> ```
|
||||
> docker run --rm -it \
|
||||
> --name reachy \
|
||||
> --privileged \
|
||||
> --network host \
|
||||
> --ipc host \
|
||||
> --device-cgroup-rule='c 189:* rwm' \
|
||||
> --group-add audio \
|
||||
> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
|
||||
> -e DISPLAY="$DISPLAY" \
|
||||
> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
|
||||
> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
|
||||
> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
|
||||
> -v /dev:/dev \
|
||||
> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
|
||||
> -v "$HOME/.reachy.log":/home/reachy/.ros/log \
|
||||
> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
|
||||
> --entrypoint /package/launch.sh \
|
||||
> pollenrobotics/reachy2_core:1.7.5.9_deploy \
|
||||
> start_rviz:=true start_sdk_server:=true mujoco:=true
|
||||
> ```
|
||||
|
||||
## Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- On your robot, check the **service images** meet the minimum versions:
|
||||
- **reachy2-core >= 1.7.5.2**
|
||||
- **webrtc >= 2.0.1.1**
|
||||
|
||||
Then, if you want to use VR teleoperation:
|
||||
|
||||
- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
|
||||
Use version **>=v1.2.0**
|
||||
|
||||
We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.
|
||||
|
||||
### Install LeRobot
|
||||
|
||||
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
|
||||
|
||||
Install LeRobot with Reachy 2 dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[reachy2]"
|
||||
```
|
||||
|
||||
### (Optional but recommended) Install pollen_data_acquisition_server
|
||||
|
||||
How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.
|
||||
|
||||
> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. You’re free to develop custom clients to manage sessions to your needs.
|
||||
|
||||
In your LeRobot environment, install the server from source:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
|
||||
cd pollen_data_acquisition_server
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).
|
||||
|
||||
## Step 1: Recording
|
||||
|
||||
### Get Reachy 2 IP address
|
||||
|
||||
Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
|
||||
We strongly recommend connecting all devices (PC and robot) via **Ethernet**.
|
||||
|
||||
### Launch recording
|
||||
|
||||
There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:
|
||||
|
||||
- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robot’s motions.
|
||||
- **Using LeRobot’s record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, it’s only for motion control.
|
||||
|
||||
### Option 1: Using Pollen data acquisition server (recommended for VR teleop)
|
||||
|
||||
Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.
|
||||
|
||||
Launch the data acquisition server to be able to manage your session directly from the teleoperation application:
|
||||
|
||||
```bash
|
||||
python -m pollen_data_acquisition_server.server
|
||||
```
|
||||
|
||||
Then get into the teleoperation application and choose "Data acquisition session".
|
||||
You can finally setup your session by following the screens displayed.
|
||||
|
||||
> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.
|
||||
|
||||
### Option 2: Using lerobot.record
|
||||
|
||||
Reachy 2 is fully supported by LeRobot’s recording features.
|
||||
If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.
|
||||
|
||||
**Example: start a recording without the mobile base:**
|
||||
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.id=r2-0000 \
|
||||
--robot.use_external_commands=true \
|
||||
--robot.with_mobile_base=false \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
--teleop.ip_address=192.168.0.200 \
|
||||
--teleop.with_mobile_base=false \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.single_task="Reachy 2 recording test" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
#### Specific Options
|
||||
|
||||
**Extended setup overview (all options included):**
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=true \
|
||||
--robot.with_mobile_base=true \
|
||||
--robot.with_l_arm=true \
|
||||
--robot.with_r_arm=true \
|
||||
--robot.with_neck=true \
|
||||
--robot.with_antennas=true \
|
||||
--robot.with_left_teleop_camera=true \
|
||||
--robot.with_right_teleop_camera=true \
|
||||
--robot.with_torso_camera=false \
|
||||
--robot.disable_torque_on_disconnect=false \
|
||||
--robot.max_relative_target=5.0 \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
--teleop.ip_address=192.168.0.200 \
|
||||
--teleop.use_present_position=false \
|
||||
--teleop.with_mobile_base=false \
|
||||
--teleop.with_l_arm=true \
|
||||
--teleop.with_r_arm=true \
|
||||
--teleop.with_neck=true \
|
||||
--teleop.with_antennas=true \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.single_task="Reachy 2 recording test" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
##### `--robot.use_external_commands`
|
||||
|
||||
Determine whether LeRobot robot.send_action() sends commands to the robot.
|
||||
**Must** be set to false while using the VR teleoperation application, as the app already sends commands.
|
||||
|
||||
##### `--teleop.use_present_position`
|
||||
|
||||
Determine whether the teleoperator reads the goal or present position of the robot.
|
||||
Must be set to true if a compliant Reachy 2 is used to control another one.
|
||||
|
||||
##### Use the relevant parts
|
||||
|
||||
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
|
||||
To avoid this, you can exclude specific parts from recording and replay using:
|
||||
|
||||
````
|
||||
--robot.with_<part>=false
|
||||
```,
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
It determine whether the corresponding part is recorded in the observations. True if not set.
|
||||
|
||||
By default, **all parts are recorded**.
|
||||
|
||||
The same per-part mechanism is available in `reachy2_teleoperator` as well.
|
||||
|
||||
````
|
||||
|
||||
--teleop.with\_<part>
|
||||
|
||||
```
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
Determine whether the corresponding part is recorded in the actions. True if not set.
|
||||
|
||||
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
|
||||
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
|
||||
|
||||
##### Use the relevant cameras
|
||||
|
||||
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
|
||||
|
||||
```
|
||||
|
||||
--robot.with_left_teleop_camera=<true|false>
|
||||
--robot.with_right_teleop_camera=<true|false>
|
||||
--robot.with_torso_camera=<true|false>
|
||||
|
||||
````
|
||||
|
||||
|
||||
## Step 2: Replay
|
||||
|
||||
Make sure the robot is configured with the same parts as the dataset:
|
||||
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=false \
|
||||
--robot.with_mobile_base=false \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.episode=0
|
||||
--display_data=true
|
||||
````
|
||||
|
||||
## Step 3: Train
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/reachy2_test \
|
||||
--job_name=reachy2 \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=pollen_robotics/record_test_policy
|
||||
```
|
||||
|
||||
## Step 4: Evaluate
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=pollen_robotics/eval_record_test \
|
||||
--dataset.single_task="Evaluate reachy2 policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
@@ -54,7 +54,7 @@ If you don't have a gpu device, you can train using our notebook on [.
|
||||
|
||||
```bash
|
||||
cd lerobot && python -m lerobot.scripts.train \
|
||||
cd lerobot && lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=${HF_USER}/mydataset \
|
||||
--batch_size=64 \
|
||||
@@ -73,7 +73,7 @@ cd lerobot && python -m lerobot.scripts.train \
|
||||
Fine-tuning is an art. For a complete overview of the options for finetuning, run
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --help
|
||||
lerobot-train --help
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
@@ -97,7 +97,7 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
|
||||
@@ -26,7 +26,7 @@ Unlike the SO-101, the motor connectors are not easily accessible once the arm i
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -93,7 +93,7 @@ For a visual reference on how to set the motor ids please refer to [this video](
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -168,7 +168,7 @@ Do the same steps for the leader arm.
|
||||
<hfoptions id="setup_motors">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -568,7 +568,7 @@ Run the following command or API example to calibrate the follower arm:
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
|
||||
@@ -606,7 +606,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -162,7 +162,7 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
|
||||
To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -240,7 +240,7 @@ Connect the usb cable from your computer and the power supply to the follower ar
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -316,7 +316,7 @@ Do the same steps for the leader arm.
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -353,7 +353,7 @@ Run the following command or API example to calibrate the follower arm:
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
|
||||
@@ -402,7 +402,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -62,7 +62,7 @@ By default, every field takes its default value specified in the dataclass. If a
|
||||
Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy.type=diffusion \
|
||||
--env.type=pusht
|
||||
@@ -77,7 +77,7 @@ Let's break this down:
|
||||
Let's see another example. Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
@@ -90,7 +90,7 @@ We now want to train a different policy for aloha on another task. We'll change
|
||||
Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
@@ -127,7 +127,7 @@ Now, let's assume that we want to reproduce the run just above. That run has pro
|
||||
We can then simply load the config values from this file using:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
```
|
||||
@@ -137,7 +137,7 @@ python -m lerobot.scripts.train \
|
||||
Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
--policy.n_action_steps=80
|
||||
@@ -148,7 +148,7 @@ python -m lerobot.scripts.train \
|
||||
`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
|
||||
lerobot-train --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)
|
||||
@@ -160,7 +160,7 @@ Being able to resume a training run is important in case it crashed or aborted f
|
||||
Let's reuse the command from the previous run and add a few more options:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
@@ -179,7 +179,7 @@ INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100
|
||||
Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true
|
||||
```
|
||||
@@ -190,7 +190,7 @@ Another reason for which you might want to resume a run is simply to extend trai
|
||||
You could double the number of steps of the previous run with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000
|
||||
@@ -224,7 +224,7 @@ In addition to the features currently in Draccus, we've added a special `.path`
|
||||
For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
@@ -270,7 +270,7 @@ We'll summarize here the main use cases to remember from this tutorial.
|
||||
#### Train a policy from scratch – CLI
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \ # <- select 'act' policy
|
||||
--env.type=pusht \ # <- select 'pusht' environment
|
||||
--dataset.repo_id=lerobot/pusht # <- train on this dataset
|
||||
@@ -279,7 +279,7 @@ python -m lerobot.scripts.train \
|
||||
#### Train a policy from scratch - config file + CLI
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=path/to/pretrained_model \ # <- can also be a repo_id
|
||||
--policy.n_action_steps=80 # <- you may still override values
|
||||
```
|
||||
@@ -287,7 +287,7 @@ python -m lerobot.scripts.train \
|
||||
#### Resume/continue a training run
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=checkpoint/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000 # <- you can change some training parameters
|
||||
@@ -296,7 +296,7 @@ python -m lerobot.scripts.train \
|
||||
#### Fine-tuning
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
|
||||
# storage / caches
|
||||
RAID=/raid/jade
|
||||
export TRANSFORMERS_CACHE=$RAID/.cache/huggingface/transformers
|
||||
export HF_HOME=$RAID/.cache/huggingface
|
||||
export HF_DATASETS_CACHE=$RAID/.cache/huggingface/datasets
|
||||
export HF_LEROBOT_HOME=$RAID/.cache/huggingface/lerobot
|
||||
export WANDB_CACHE_DIR=$RAID/.cache/wandb
|
||||
export TMPDIR=$RAID/.cache/tmp
|
||||
mkdir -p $TMPDIR
|
||||
export WANDB_MODE=offline
|
||||
export HF_DATASETS_OFFLINE=1
|
||||
export HF_HUB_OFFLINE=1
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
export MUJOCO_GL=egl
|
||||
export CUDA_VISIBLE_DEVICES=2
|
||||
|
||||
# CONFIGURATION
|
||||
POLICY_PATH="/raid/jade/logs/lerobot/lerobot_2_HuggingFaceVLA_libero_smolvla_lr1e-4bs32steps100000/checkpoints/100000/pretrained_model"
|
||||
POLICY_PATH="/raid/jade/models/smolvlamust"
|
||||
TASK=libero_spatial,libero_object
|
||||
ENV_TYPE="libero"
|
||||
BATCH_SIZE=1
|
||||
N_EPISODES=1
|
||||
# storage / caches
|
||||
RAID=/raid/jade
|
||||
N_ACTION_STEPS=1
|
||||
export TRANSFORMERS_CACHE=$RAID/.cache/huggingface/transformers
|
||||
export HF_HOME=$RAID/.cache/huggingface
|
||||
export HF_DATASETS_CACHE=$RAID/.cache/huggingface/datasets
|
||||
export HF_LEROBOT_HOME=$RAID/.cache/huggingface/lerobot
|
||||
export WANDB_CACHE_DIR=$RAID/.cache/wandb
|
||||
export TMPDIR=$RAID/.cache/tmp
|
||||
mkdir -p $TMPDIR
|
||||
export WANDB_MODE=offline
|
||||
# export HF_DATASETS_OFFLINE=1
|
||||
# export HF_HUB_OFFLINE=1
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
export MUJOCO_GL=egl
|
||||
export MUJOCO_GL=egl
|
||||
unset HF_HUB_OFFLINE
|
||||
# RUN EVALUATION
|
||||
python src/lerobot/scripts/eval.py \
|
||||
--policy.path="$POLICY_PATH" \
|
||||
--env.type="$ENV_TYPE" \
|
||||
--eval.batch_size="$BATCH_SIZE" \
|
||||
--eval.n_episodes="$N_EPISODES" \
|
||||
--env.multitask_eval=True \
|
||||
--env.task=$TASK \
|
||||
# python examples/evaluate_libero.py \
|
||||
# --policy_path "$POLICY_PATH" \
|
||||
# --task_suite_name "$TASK" \
|
||||
# --num_steps_wait 10 \
|
||||
# --num_trials_per_task 10 \
|
||||
# --video_out_path "data/libero/videos" \
|
||||
# --device "cuda" \
|
||||
# --seed 7
|
||||
@@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_processor
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
@@ -12,14 +11,12 @@ NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
policy = ACTPolicy.from_pretrained("<hf_username>/<policy_repo_id>")
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
@@ -28,7 +25,7 @@ dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
repo_id="<hf_username>/<eval_dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
@@ -46,12 +43,6 @@ listener, events = init_keyboard_listener()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
@@ -62,8 +53,6 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
|
||||
@@ -38,7 +38,7 @@ while True:
|
||||
keyboard_keys = keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
|
||||
log_rerun_data(observation=observation, action={**arm_action, **base_action})
|
||||
log_rerun_data(observation, {**arm_action, **base_action})
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import merge_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_processor
|
||||
from lerobot.processor.converters import (
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot with degrees
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose = RobotProcessor(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=to_transition_robot_observation,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build dataset action and gripper features
|
||||
action_ee_and_gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints,
|
||||
initial_features={},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
|
||||
) # Get all ee action features + gripper pos action features
|
||||
|
||||
# Build dataset observation features
|
||||
obs_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=robot.observation_features,
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
) # Get all ee observation features
|
||||
|
||||
dataset_features = merge_features(obs_ee, action_ee_and_gripper)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
robot_action_processor=robot_ee_to_joints,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
dataset.push_to_hub()
|
||||
@@ -1,215 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import merge_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.converters import (
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
ForwardKinematicsJointsToEE,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone import Phone
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
|
||||
NUM_EPISODES = 10
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action
|
||||
phone_to_robot_ee_pose = RobotProcessor(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose = RobotProcessor(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=to_transition_robot_observation,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build dataset ee action features
|
||||
action_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose,
|
||||
initial_features=phone.action_features,
|
||||
use_videos=True,
|
||||
patterns=["action.ee"],
|
||||
)
|
||||
|
||||
# Get gripper pos action features
|
||||
gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints,
|
||||
initial_features={},
|
||||
use_videos=True,
|
||||
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
|
||||
)
|
||||
|
||||
# Build dataset ee observation features
|
||||
observation_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=robot.observation_features,
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
)
|
||||
|
||||
dataset_features = merge_features(action_ee, gripper, observation_ee)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose,
|
||||
robot_action_processor=robot_ee_to_joints,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose,
|
||||
robot_action_processor=robot_ee_to_joints,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
dataset.push_to_hub()
|
||||
@@ -1,106 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.converters import to_output_robot_action
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
|
||||
# This method converts the action from the dataset to a transition for pipeline
|
||||
def action_to_transition(action: dict):
|
||||
act = {}
|
||||
|
||||
# EE pose
|
||||
for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"):
|
||||
if k in action:
|
||||
act[f"action.{k}"] = float(action[k])
|
||||
|
||||
# Gripper: your dataset has absolute position
|
||||
if "gripper.pos" in action:
|
||||
act["action.gripper.pos"] = float(action["gripper.pos"])
|
||||
|
||||
return {
|
||||
"observation": None,
|
||||
"action": act,
|
||||
"reward": None,
|
||||
"done": False,
|
||||
"truncated": False,
|
||||
"info": {},
|
||||
"complementary_data": {},
|
||||
}
|
||||
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=action_to_transition,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
robot_ee_to_joints.reset()
|
||||
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(dataset.num_frames):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
ee_action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
}
|
||||
|
||||
joint_action = robot_ee_to_joints(ee_action)
|
||||
action_sent = robot.send_action(joint_action)
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
robot.disconnect()
|
||||
@@ -1,109 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specif
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone import Phone
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action
|
||||
phone_to_robot_ee_pose = RobotProcessor(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot.")
|
||||
while True:
|
||||
phone_obs = teleop_device.get_action()
|
||||
if not phone_obs:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
# Get teleop observation
|
||||
phone_obs = teleop_device.get_action()
|
||||
|
||||
# Phone to EE pose transition
|
||||
ee_transition = phone_to_robot_ee_pose(phone_obs)
|
||||
|
||||
# EE pose to Joints transition
|
||||
joint_action = robot_ee_to_joints(ee_transition)
|
||||
|
||||
if joint_action:
|
||||
robot.send_action(joint_action)
|
||||
|
||||
time.sleep(0.01)
|
||||
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Script to create and push a PI0OpenPI model to HuggingFace hub with proper config format."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi, create_repo
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||
|
||||
|
||||
def create_and_push_model(
|
||||
repo_id: str,
|
||||
private: bool = False,
|
||||
token: str = None,
|
||||
):
|
||||
"""Create a PI0OpenPI model with proper config and push to HuggingFace hub.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace repository ID (e.g., "username/model-name")
|
||||
private: Whether to create a private repository
|
||||
token: HuggingFace API token (optional, will use cached token if not provided)
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("PI0OpenPI Model Hub Upload")
|
||||
print("=" * 60)
|
||||
|
||||
# Create configuration
|
||||
print("\nCreating PI0OpenPI configuration...")
|
||||
config = PI0OpenPIConfig(
|
||||
# Model architecture
|
||||
paligemma_variant="gemma_2b",
|
||||
action_expert_variant="gemma_300m",
|
||||
pi05=False, # Use PI0 (not PI0.5)
|
||||
dtype="float32", # Use float32 for compatibility
|
||||
# Input/output dimensions
|
||||
action_dim=32, # see openpi `Pi0Config`
|
||||
state_dim=32,
|
||||
chunk_size=50,
|
||||
n_action_steps=50,
|
||||
# Image inputs, see openpi `model.py, IMAGE_KEYS`
|
||||
image_keys=(
|
||||
"observation.images.base_0_rgb",
|
||||
"observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist_0_rgb",
|
||||
),
|
||||
# Training settings
|
||||
gradient_checkpointing=False,
|
||||
compile_model=False,
|
||||
device=None, # Auto-detect
|
||||
# Tokenizer settings
|
||||
tokenizer_max_length=48, # see openpi `__post_init__`, use pi0=48 and pi05=200
|
||||
)
|
||||
|
||||
print(f" - Config type: {config.__class__.__name__}")
|
||||
print(f" - PaliGemma variant: {config.paligemma_variant}")
|
||||
print(f" - Action expert variant: {config.action_expert_variant}")
|
||||
print(f" - Action dim: {config.action_dim}")
|
||||
print(f" - State dim: {config.state_dim}")
|
||||
|
||||
# Create dummy dataset stats for normalization
|
||||
print("\nCreating dataset statistics...")
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(config.state_dim),
|
||||
"std": torch.ones(config.state_dim),
|
||||
"min": torch.full((config.state_dim,), -5.0),
|
||||
"max": torch.full((config.state_dim,), 5.0),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(config.action_dim),
|
||||
"std": torch.ones(config.action_dim),
|
||||
"min": torch.full((config.action_dim,), -1.0),
|
||||
"max": torch.full((config.action_dim,), 1.0),
|
||||
},
|
||||
}
|
||||
|
||||
# Add image stats
|
||||
for key in config.image_keys:
|
||||
dataset_stats[key] = {
|
||||
"mean": torch.tensor([0.485, 0.456, 0.406]), # TODO(pepijn): fix this, now its ImageNet mean
|
||||
"std": torch.tensor([0.229, 0.224, 0.225]), # TODO(pepijn): fix this, now its ImageNet std
|
||||
"min": torch.tensor([0.0, 0.0, 0.0]),
|
||||
"max": torch.tensor([1.0, 1.0, 1.0]),
|
||||
}
|
||||
|
||||
# Create the policy
|
||||
print("\nInitializing PI0OpenPI policy...")
|
||||
print(" (This may take a moment as it loads the tokenizer and initializes the model)")
|
||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
||||
|
||||
# Initialize with small random weights (optional - for testing)
|
||||
# Note: In practice, you would load your trained weights here
|
||||
print("\nInitializing model weights...")
|
||||
for name, param in policy.named_parameters():
|
||||
if "weight" in name:
|
||||
if "norm" in name.lower() or "layernorm" in name.lower():
|
||||
torch.nn.init.ones_(param)
|
||||
elif len(param.shape) >= 2:
|
||||
torch.nn.init.xavier_uniform_(param, gain=0.01)
|
||||
else:
|
||||
torch.nn.init.normal_(param, mean=0.0, std=0.01)
|
||||
elif "bias" in name:
|
||||
torch.nn.init.zeros_(param)
|
||||
|
||||
print(f" - Total parameters: {sum(p.numel() for p in policy.parameters()):,}")
|
||||
print(f" - Trainable parameters: {sum(p.numel() for p in policy.parameters() if p.requires_grad):,}")
|
||||
|
||||
# Create temporary directory for saving
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
save_path = Path(tmpdir) / "model"
|
||||
save_path.mkdir(exist_ok=True)
|
||||
|
||||
print(f"\nSaving model to temporary directory: {save_path}")
|
||||
|
||||
# Save the model using LeRobot's save_pretrained method
|
||||
# This ensures the config is saved in the correct format
|
||||
policy.save_pretrained(save_path)
|
||||
|
||||
# List saved files
|
||||
saved_files = list(save_path.glob("*"))
|
||||
print("\nSaved files:")
|
||||
for file in saved_files:
|
||||
size = file.stat().st_size
|
||||
print(f" - {file.name}: {size:,} bytes")
|
||||
|
||||
# Create or get repository
|
||||
print(f"\nCreating/accessing repository: {repo_id}")
|
||||
api = HfApi(token=token)
|
||||
|
||||
try:
|
||||
# Create repo if it doesn't exist
|
||||
create_repo(
|
||||
repo_id,
|
||||
private=private,
|
||||
token=token,
|
||||
exist_ok=True,
|
||||
)
|
||||
print(f" ✓ Repository ready: https://huggingface.co/{repo_id}")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Note: {e}")
|
||||
|
||||
# Upload to hub
|
||||
print("\nUploading to HuggingFace hub...")
|
||||
api.upload_folder(
|
||||
folder_path=str(save_path),
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
token=token,
|
||||
commit_message="Upload PI0OpenPI model with proper LeRobot config format",
|
||||
)
|
||||
|
||||
print(f"\n✓ Model successfully uploaded to: https://huggingface.co/{repo_id}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ Process complete!")
|
||||
print("=" * 60)
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Push PI0OpenPI model to HuggingFace hub")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="test-user/pi0-openpi-test",
|
||||
help="HuggingFace repository ID (e.g., 'username/model-name')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="Create a private repository",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="HuggingFace API token (optional, uses cached token if not provided)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run the upload
|
||||
create_and_push_model(
|
||||
repo_id=args.repo_id,
|
||||
private=args.private,
|
||||
token=args.token,
|
||||
)
|
||||
+27
-8
@@ -29,7 +29,7 @@ version = "0.3.4"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.11"
|
||||
authors = [
|
||||
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
|
||||
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
|
||||
@@ -50,7 +50,7 @@ classifiers = [
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
@@ -73,7 +73,6 @@ dependencies = [
|
||||
"pynput>=1.7.7",
|
||||
"pyserial>=3.5",
|
||||
"wandb>=0.20.0",
|
||||
"scipy>=1.15.2",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
@@ -96,7 +95,7 @@ dependencies = [
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1"]
|
||||
placo-dep = ["placo>=0.9.6"]
|
||||
transformers-dep = ["transformers<=4.52.0"]
|
||||
transformers-dep = ["transformers==4.53.2"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||
|
||||
# Motors
|
||||
@@ -107,12 +106,12 @@ dynamixel = ["dynamixel-sdk>=3.7.31"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
|
||||
# stretch = [
|
||||
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
|
||||
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
@@ -136,13 +135,33 @@ video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
aloha = ["gym-aloha>=0.1.1"]
|
||||
pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
xarm = ["gym-xarm>=0.1.1"]
|
||||
|
||||
libero = [
|
||||
"hydra-core>=1.2,<1.4",
|
||||
"numpy",
|
||||
"wandb",
|
||||
"easydict",
|
||||
"transformers",
|
||||
"opencv-python",
|
||||
"robomimic==0.2.0",
|
||||
"einops",
|
||||
"thop",
|
||||
"robosuite==1.4.0",
|
||||
"mujoco>=2.3.7,<3.0.0",
|
||||
"bddl==1.0.1",
|
||||
"matplotlib",
|
||||
"cloudpickle",
|
||||
"future",
|
||||
"gym",
|
||||
"egl_probe @ git+https://github.com/jadechoghari/egl_probe.git#egg=egl_probe",
|
||||
"libero @ git+https://github.com/jadechoghari/LIBERO.git@main#egg=libero",
|
||||
]
|
||||
# All
|
||||
all = [
|
||||
"lerobot[dynamixel]",
|
||||
"lerobot[gamepad]",
|
||||
"lerobot[hopejr]",
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi0]",
|
||||
@@ -155,7 +174,7 @@ all = [
|
||||
"lerobot[aloha]",
|
||||
"lerobot[pusht]",
|
||||
"lerobot[xarm]",
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -261,7 +280,7 @@ default.extend-ignore-identifiers-re = [
|
||||
# paths = ["src/lerobot"]
|
||||
|
||||
# [tool.mypy]
|
||||
# python_version = "3.10"
|
||||
# python_version = "3.11"
|
||||
# warn_return_any = true
|
||||
# warn_unused_configs = true
|
||||
# ignore_missing_imports = false
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Convert video dataset to image dataset for faster training.
|
||||
This pre-extracts all frames from MP4 files to PNG images.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
def convert_dataset_videos_to_images(repo_id: str, root: str | None = None):
|
||||
"""Convert all videos in a LeRobot dataset to individual image files."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import decode_video_frames
|
||||
import torch
|
||||
|
||||
# Load dataset
|
||||
dataset = LeRobotDataset(repo_id, root=root, download_videos=True)
|
||||
|
||||
total_frames_processed = 0
|
||||
|
||||
for ep_idx in range(dataset.meta.total_episodes):
|
||||
logging.info(f"Processing episode {ep_idx}/{dataset.meta.total_episodes}")
|
||||
|
||||
for vid_key in dataset.meta.video_keys:
|
||||
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, vid_key)
|
||||
|
||||
if not video_path.exists():
|
||||
logging.warning(f"Video not found: {video_path}")
|
||||
continue
|
||||
|
||||
# Create image directory
|
||||
img_dir = dataset.root / f"images/chunk-{dataset.meta.get_episode_chunk(ep_idx)}/{vid_key}"
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Decode all frames from video
|
||||
# Get episode length to decode all frames
|
||||
ep_length = dataset.meta.episodes[ep_idx]["length"]
|
||||
timestamps = [i / dataset.fps for i in range(ep_length)]
|
||||
|
||||
try:
|
||||
frames = decode_video_frames(video_path, timestamps, dataset.tolerance_s, dataset.video_backend)
|
||||
|
||||
# Save each frame as PNG
|
||||
for i, frame in enumerate(frames.squeeze(0)):
|
||||
img_path = img_dir / f"episode_{ep_idx:06d}_{i:06d}.png"
|
||||
# Convert tensor to PIL and save
|
||||
import torchvision.transforms as T
|
||||
to_pil = T.ToPILImage()
|
||||
pil_frame = to_pil(frame)
|
||||
pil_frame.save(img_path)
|
||||
|
||||
total_frames_processed += len(frames.squeeze(0))
|
||||
logging.info(f" Extracted {len(frames.squeeze(0))} frames to {img_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to process {video_path}: {e}")
|
||||
continue
|
||||
|
||||
logging.info(f"Conversion complete! Processed {total_frames_processed} total frames")
|
||||
logging.info(f"You can now use download_videos=False to use the extracted images")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert LeRobot video dataset to images")
|
||||
parser.add_argument("repo_id", help="Dataset repo ID (e.g., 'kenmacken/record-test-2')")
|
||||
parser.add_argument("--root", help="Local root directory", default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
convert_dataset_videos_to_images(args.repo_id, args.root)
|
||||
@@ -18,7 +18,7 @@ Helper to recalibrate your device (robot or teleoperator).
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
|
||||
@@ -60,7 +60,7 @@ class OpenCVCamera(Camera):
|
||||
or port changes, especially on Linux. Use the provided utility script to find
|
||||
available camera indices or paths:
|
||||
```bash
|
||||
python -m lerobot.find_cameras opencv
|
||||
lerobot-find-cameras opencv
|
||||
```
|
||||
|
||||
The camera's default settings (FPS, resolution, color mode) are used unless
|
||||
@@ -165,8 +165,7 @@ class OpenCVCamera(Camera):
|
||||
self.videocapture.release()
|
||||
self.videocapture = None
|
||||
raise ConnectionError(
|
||||
f"Failed to open {self}."
|
||||
f"Run `python -m lerobot.find_cameras opencv` to find available cameras."
|
||||
f"Failed to open {self}.Run `lerobot-find-cameras opencv` to find available cameras."
|
||||
)
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
+3
-5
@@ -1,6 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,5 +12,5 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_phone import PhoneConfig
|
||||
from .phone import Phone
|
||||
from .configuration_reachy2_camera import Reachy2CameraConfig
|
||||
from .reachy2_camera import Reachy2Camera
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("reachy2_camera")
|
||||
@dataclass
|
||||
class Reachy2CameraConfig(CameraConfig):
|
||||
"""Configuration class for Reachy 2 camera devices.
|
||||
|
||||
This class provides configuration options for Reachy 2 cameras,
|
||||
supporting both the teleop and depth cameras. It includes settings
|
||||
for resolution, frame rate, color mode, and the selection of the cameras.
|
||||
|
||||
Example configurations:
|
||||
```python
|
||||
# Basic configurations
|
||||
Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="left",
|
||||
ip_address="192.168.0.200", # IP address of the robot
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
) # Left teleop camera, 640x480 @ 15FPS
|
||||
```
|
||||
|
||||
Attributes:
|
||||
name: Name of the camera device. Can be "teleop" or "depth".
|
||||
image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
|
||||
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
|
||||
fps: Requested frames per second for the color stream.
|
||||
width: Requested frame width in pixels for the color stream.
|
||||
height: Requested frame height in pixels for the color stream.
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
ip_address: IP address of the robot. Defaults to "localhost".
|
||||
port: Port number for the camera server. Defaults to 50065.
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
"""
|
||||
|
||||
name: str
|
||||
image_type: str
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
ip_address: str | None = "localhost"
|
||||
port: int = 50065
|
||||
# use_depth: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.name not in ["teleop", "depth"]:
|
||||
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
|
||||
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
|
||||
self.name == "depth" and self.image_type not in ["rgb", "depth"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
|
||||
)
|
||||
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
@@ -0,0 +1,288 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Reachy2Camera(Camera):
|
||||
"""
|
||||
Manages Reachy 2 camera using Reachy 2 CameraManager.
|
||||
|
||||
This class provides a high-level interface to connect to, configure, and read
|
||||
frames from Reachy 2 cameras. It supports both synchronous and asynchronous
|
||||
frame reading.
|
||||
|
||||
An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
|
||||
type (e.g., "left") to be specified in the configuration.
|
||||
|
||||
The camera's default settings (FPS, resolution, color mode) are used unless
|
||||
overridden in the configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Reachy2CameraConfig):
|
||||
"""
|
||||
Initializes the Reachy2Camera instance.
|
||||
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
self.fps = config.fps
|
||||
self.color_mode = config.color_mode
|
||||
|
||||
self.cam_manager: CameraManager | None = None
|
||||
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
if self.config.name == "teleop":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
elif self.config.name == "depth":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
else:
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
"""
|
||||
Connects to the Reachy2 CameraManager as specified in the configuration.
|
||||
"""
|
||||
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
|
||||
self.cam_manager.initialize_cameras()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@staticmethod
|
||||
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Detects available Reachy 2 cameras.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains 'name', 'stereo',
|
||||
and the default profile properties (width, height, fps).
|
||||
"""
|
||||
initialized_cameras = []
|
||||
camera_manager = CameraManager(host=ip_address, port=port)
|
||||
|
||||
for camera in [camera_manager.teleop, camera_manager.depth]:
|
||||
if camera is None:
|
||||
continue
|
||||
|
||||
height, width, _, _, _, _, _ = camera.get_parameters()
|
||||
|
||||
camera_info = {
|
||||
"name": camera._cam_info.name,
|
||||
"stereo": camera._cam_info.stereo,
|
||||
"default_profile": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"fps": 30,
|
||||
},
|
||||
}
|
||||
initialized_cameras.append(camera_info)
|
||||
|
||||
camera_manager.disconnect()
|
||||
return initialized_cameras
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This is a blocking call.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
(height, width, channels), using the specified or default
|
||||
color mode and applying any configured rotation.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
frame = None
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
else:
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
if self.config.image_type == "left":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
|
||||
elif self.config.image_type == "right":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
|
||||
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||
if self.config.image_type == "depth":
|
||||
frame = self.cam_manager.depth.get_depth_frame()[0]
|
||||
elif self.config.image_type == "rgb":
|
||||
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
|
||||
|
||||
if frame is None:
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.config.color_mode == "rgb":
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.new_frame_event.set()
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
This method retrieves the most recent frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms (0.2 seconds).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame as a NumPy array in the format
|
||||
(height, width, channels), processed according to configuration.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is already disconnected.
|
||||
"""
|
||||
if not self.is_connected and self.thread is None:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
if self.thread is not None:
|
||||
self._stop_read_thread()
|
||||
|
||||
if self.cam_manager is not None:
|
||||
self.cam_manager.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -51,7 +51,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Use the provided utility script to find available camera indices and default profiles:
|
||||
```bash
|
||||
python -m lerobot.find_cameras realsense
|
||||
lerobot-find-cameras realsense
|
||||
```
|
||||
|
||||
A `RealSenseCamera` instance requires a configuration object specifying the
|
||||
@@ -176,8 +176,7 @@ class RealSenseCamera(Camera):
|
||||
self.rs_profile = None
|
||||
self.rs_pipeline = None
|
||||
raise ConnectionError(
|
||||
f"Failed to open {self}."
|
||||
"Run `python -m lerobot.find_cameras realsense` to find available cameras."
|
||||
f"Failed to open {self}.Run `lerobot-find-cameras realsense` to find available cameras."
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
@@ -37,8 +37,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
from .realsense.camera_realsense import RealSenseCamera
|
||||
|
||||
cameras[key] = RealSenseCamera(cfg)
|
||||
|
||||
elif cfg.type == "reachy2_camera":
|
||||
from .reachy2_camera.reachy2_camera import Reachy2Camera
|
||||
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
@@ -33,8 +33,6 @@ class DatasetConfig:
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
# Percentage of dataset to use (0-100). If set, overrides episodes parameter.
|
||||
percentage: float | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
|
||||
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
@@ -53,6 +53,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
@@ -71,9 +72,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
tags: list[str] | None = None
|
||||
# Add tags to your policy on the hub.
|
||||
license: str | None = None
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.pretrained_path = None
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
|
||||
@@ -24,7 +24,6 @@ class FeatureType(str, Enum):
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
LANGUAGE = "LANGUAGE"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
|
||||
@@ -21,7 +21,6 @@ OBS_ENV_STATE = "observation.environment_state"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
OBS_LANGUAGE = "observation.language"
|
||||
ACTION = "action"
|
||||
REWARD = "next.reward"
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
@@ -86,24 +87,10 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
|
||||
# Handle percentage parameter
|
||||
episodes = cfg.dataset.episodes
|
||||
if cfg.dataset.percentage is not None:
|
||||
# Calculate episodes based on percentage
|
||||
total_episodes = ds_meta.total_episodes
|
||||
num_episodes_to_use = max(1, int(total_episodes * cfg.dataset.percentage / 100))
|
||||
episodes = list(range(num_episodes_to_use))
|
||||
import logging
|
||||
|
||||
logging.info(
|
||||
f"Using {cfg.dataset.percentage}% of dataset: {num_episodes_to_use}/{total_episodes} episodes"
|
||||
)
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=episodes,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
|
||||
@@ -825,6 +825,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
else:
|
||||
episode_buffer = episode_data
|
||||
|
||||
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
||||
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
|
||||
|
||||
def aggregate_pipeline_dataset_features(
|
||||
pipeline: RobotProcessor,
|
||||
initial_features: dict[str, Any],
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates the pipeline's features and returns a features dict ready for the dataset,
|
||||
filtered to only those keys matching any of the given patterns (for action/state only).
|
||||
|
||||
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
|
||||
- `use_videos`: whether to treat image features as video streams
|
||||
- `patterns`: regexes to filter action & state features; images are included
|
||||
whenever use_videos=True, regardless of patterns.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Gather everything the pipeline features specifies, seeded with hardware cams:
|
||||
all_features = pipeline.transform_features(initial_features)
|
||||
|
||||
# Helper to decide which action/state keys survive the `patterns` filter:
|
||||
def keep(key: str) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
|
||||
# Start with hardware dict, injecting initial cameras if videos are ON:
|
||||
hw: dict[str, dict[str, Any]] = {}
|
||||
if use_videos:
|
||||
cams = {
|
||||
name: shape
|
||||
for name, shape in initial_features.items()
|
||||
if isinstance(shape, tuple) and len(shape) == 3
|
||||
}
|
||||
if cams:
|
||||
hw["observation"] = dict(cams)
|
||||
|
||||
# Go over every feature from the pipeline and merge:
|
||||
for full_key, ty in all_features.items():
|
||||
if full_key.startswith("action."):
|
||||
# action.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len("action.") :]
|
||||
hw.setdefault("action", {})[name] = ty
|
||||
|
||||
elif full_key.startswith("observation.state."):
|
||||
# observation.state.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len("observation.state.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
elif full_key.startswith("observation.images."):
|
||||
# observation.images.<cam>
|
||||
# images obey ONLY the use_videos flag, not patterns
|
||||
if not use_videos:
|
||||
continue
|
||||
name = full_key[len("observation.images.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
else:
|
||||
# anything else (e.g. policy-only features) is ignored here
|
||||
continue
|
||||
|
||||
out: dict[str, dict] = {}
|
||||
if "action" in hw:
|
||||
out.update(hw_to_dataset_features(hw["action"], "action", use_videos))
|
||||
if "observation" in hw:
|
||||
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
|
||||
|
||||
return out
|
||||
@@ -470,50 +470,6 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
return policy_features
|
||||
|
||||
|
||||
def merge_features(*dicts: dict) -> dict:
|
||||
"""
|
||||
Merge LeRobot grouped feature dicts.
|
||||
|
||||
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
|
||||
- For others (observation.images.*), last one wins (if they are identical).
|
||||
"""
|
||||
out: dict = {}
|
||||
for d in dicts:
|
||||
for key, value in d.items():
|
||||
if not isinstance(value, dict):
|
||||
out[key] = value
|
||||
continue
|
||||
|
||||
dtype = value.get("dtype")
|
||||
shape = value.get("shape")
|
||||
is_vector = (
|
||||
dtype not in ("image", "video", "string")
|
||||
and isinstance(shape, tuple)
|
||||
and len(shape) == 1
|
||||
and "names" in value
|
||||
)
|
||||
|
||||
if is_vector:
|
||||
# Initialize or retrieve the accumulating dict for this feature key
|
||||
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
|
||||
# Ensure consistent data types across merged entries
|
||||
if "dtype" in target and dtype != target["dtype"]:
|
||||
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
|
||||
|
||||
# Merge feature names: append only new ones to preserve order without duplicates
|
||||
seen = set(target["names"])
|
||||
for n in value["names"]:
|
||||
if n not in seen:
|
||||
target["names"].append(n)
|
||||
seen.add(n)
|
||||
# Recompute the shape to reflect the updated number of features
|
||||
target["shape"] = (len(target["names"]),)
|
||||
else:
|
||||
# For images/videos and non-1D entries: override with the latest definition
|
||||
out[key] = value
|
||||
return out
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
|
||||
@@ -13,24 +13,20 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script converts a LeRobot dataset already pushed to the Hub from codebase version 2.0 to 2.1.
|
||||
It downloads metadata from a SOURCE dataset repo, computes/validates per-episode stats, updates
|
||||
the codebase version in `info.json`, and uploads the result to a DESTINATION dataset repo.
|
||||
It will:
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the destination repo/branch and tag it with the current codebase version.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
|
||||
--source-repo-id=namespace/source_dataset \
|
||||
--dest-repo-id=namespace/destination_dataset \
|
||||
--branch=main
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -58,67 +54,48 @@ class SuppressWarnings:
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
source_repo_id: str,
|
||||
dest_repo_id: str,
|
||||
repo_id: str,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
# Download metadata from the source repo at v2.0
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(source_repo_id, revision=V20, force_cache_sync=True)
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
# Ensure we recompute fresh episodes stats
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
|
||||
# Compute and validate stats
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
# Update codebase version in info.json
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
# Remove deprecated stats.json locally so it won't be uploaded
|
||||
if (dataset.root / STATS_PATH).is_file():
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
# Push only meta/ to destination repo
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(repo_id=dest_repo_id, private=False, repo_type="dataset", exist_ok=True)
|
||||
if branch:
|
||||
hub_api.create_branch(repo_id=dest_repo_id, branch=branch, repo_type="dataset", exist_ok=True)
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=dest_repo_id,
|
||||
folder_path=str(dataset.root),
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
allow_patterns="meta/",
|
||||
)
|
||||
|
||||
# Ensure old stats.json is deleted on destination
|
||||
if hub_api.file_exists(repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"):
|
||||
hub_api.delete_file(path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset")
|
||||
|
||||
# Tag destination with current codebase version
|
||||
hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--source-repo-id",
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Source dataset repo id to download from (must be v2.0).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dest-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Destination dataset repo id to upload the converted metadata to.",
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
|
||||
@@ -30,6 +30,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
fps: int = 30
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(default_factory=dict)
|
||||
multitask_eval: bool = False
|
||||
max_parallel_tasks: int = 5
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
@@ -271,3 +273,53 @@ class HILEnvConfig(EnvConfig):
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("libero")
|
||||
@dataclass
|
||||
class LiberoEnv(EnvConfig):
|
||||
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
||||
fps: int = 30
|
||||
episode_length: int = 520
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
||||
init_states: bool = True
|
||||
multitask_eval: bool = True
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels/agentview_image": f"{OBS_IMAGES}.image",
|
||||
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
)
|
||||
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
|
||||
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
)
|
||||
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
}
|
||||
|
||||
+35
-12
@@ -17,7 +17,7 @@ import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, LiberoEnv, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -29,11 +29,15 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "hil":
|
||||
return HILEnvConfig(**kwargs)
|
||||
elif env_type == "libero":
|
||||
return LiberoEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
||||
def make_env(
|
||||
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
|
||||
) -> gym.vector.VectorEnv | dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""Makes a gym vector environment according to the config.
|
||||
|
||||
Args:
|
||||
@@ -48,24 +52,43 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
|
||||
Returns:
|
||||
gym.vector.VectorEnv: The parallelized gym.env instance.
|
||||
dict[str, dict[int, gym.vector.VectorEnv]]: A mapping from task suite
|
||||
names to indexed vectorized environments (when multitask eval is used).
|
||||
|
||||
"""
|
||||
if n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
raise ValueError("`n_envs` must be at least 1")
|
||||
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
|
||||
if "libero" in cfg.type:
|
||||
from lerobot.envs.libero import create_libero_envs
|
||||
|
||||
return create_libero_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
camera_name=cfg.camera_name,
|
||||
init_states=cfg.init_states,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
multitask_eval=cfg.multitask_eval,
|
||||
)
|
||||
|
||||
package_name = f"gym_{cfg.type}"
|
||||
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
raise e
|
||||
raise ModuleNotFoundError(
|
||||
f'{package_name} is not installed. Install with: pip install "lerobot[{cfg.type}]"'
|
||||
) from e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
)
|
||||
def _make_one():
|
||||
return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {}))
|
||||
|
||||
return env
|
||||
vec = env_cls([_make_one for _ in range(n_envs)])
|
||||
|
||||
# normalize to {suite: {task_id: vec_env}} for consistency
|
||||
suite_name = cfg.type # e.g., "pusht", "aloha"
|
||||
return {suite_name: {0: vec}}
|
||||
|
||||
@@ -0,0 +1,497 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---- Helpers -----------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
if isinstance(camera_name, str):
|
||||
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||
elif isinstance(camera_name, (list, tuple)):
|
||||
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
||||
else:
|
||||
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
||||
if not cams:
|
||||
raise ValueError("camera_name resolved to an empty list.")
|
||||
return cams
|
||||
|
||||
|
||||
def _get_suite(name: str):
|
||||
"""Instantiate a LIBERO suite by name with clear validation."""
|
||||
bench = benchmark.get_benchmark_dict()
|
||||
if name not in bench:
|
||||
raise ValueError(f"Unknown LIBERO suite '{name}'. Available: {', '.join(sorted(bench.keys()))}")
|
||||
suite = bench[name]()
|
||||
if not getattr(suite, "tasks", None):
|
||||
raise ValueError(f"Suite '{name}' has no tasks.")
|
||||
return suite
|
||||
|
||||
|
||||
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
|
||||
"""Validate/normalize task ids. If None → all tasks."""
|
||||
if task_ids is None:
|
||||
return list(range(total_tasks))
|
||||
ids = sorted({int(t) for t in task_ids})
|
||||
for t in ids:
|
||||
if t < 0 or t >= total_tasks:
|
||||
raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
|
||||
return ids
|
||||
|
||||
|
||||
def _make_env_fns(
|
||||
*,
|
||||
suite,
|
||||
suite_name: str,
|
||||
task_id: int,
|
||||
n_envs: int,
|
||||
camera_names: list[str],
|
||||
init_states: bool,
|
||||
gym_kwargs: Mapping[str, Any],
|
||||
LiberoEnv: type, # injected to avoid forward ref issues if needed
|
||||
) -> list[Callable[[], LiberoEnv]]:
|
||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||
joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string
|
||||
fns: list[Callable[[], LiberoEnv]] = []
|
||||
for i in range(n_envs):
|
||||
|
||||
def _mk(
|
||||
i=i,
|
||||
suite=suite,
|
||||
task_id=task_id,
|
||||
suite_name=suite_name,
|
||||
joined_cams=joined_cams,
|
||||
init_states=init_states,
|
||||
gym_kwargs=dict(gym_kwargs),
|
||||
):
|
||||
return LiberoEnv(
|
||||
task_suite=suite,
|
||||
task_id=task_id,
|
||||
task_suite_name=suite_name,
|
||||
camera_name=joined_cams,
|
||||
init_states=init_states,
|
||||
episode_index=i,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
fns.append(_mk)
|
||||
return fns
|
||||
|
||||
|
||||
# ---- Main API ----------------------------------------------------------------
|
||||
|
||||
|
||||
def create_libero_envs(
|
||||
task: str,
|
||||
n_envs: int,
|
||||
gym_kwargs: dict[str, Any] | None = None,
|
||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||
init_states: bool = True,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
multitask_eval: bool = True, # kept for signature compatibility; return type is consistent regardless
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""
|
||||
Create vectorized LIBERO environments with a consistent return shape.
|
||||
|
||||
Returns:
|
||||
dict[suite_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
|
||||
Notes:
|
||||
- n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1).
|
||||
- `task` can be a single suite or a comma-separated list of suites.
|
||||
- You may pass `task_ids` (list[int]) inside `gym_kwargs` to restrict tasks per suite.
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
gym_kwargs = dict(gym_kwargs or {})
|
||||
task_ids_filter = gym_kwargs.pop("task_ids", None) # optional: limit to specific tasks
|
||||
|
||||
# Avoid circular import/type issues: assume LiberoEnv is defined in this module
|
||||
try:
|
||||
LiberoEnv # type: ignore[name-defined]
|
||||
except NameError:
|
||||
# If LiberoEnv is in the same file, this won't run. If it's elsewhere, import here.
|
||||
exit()
|
||||
# from .libero_env import LiberoEnv # adjust if your class lives in another module
|
||||
|
||||
camera_names = _parse_camera_names(camera_name)
|
||||
suite_names = [s.strip() for s in str(task).split(",") if s.strip()]
|
||||
if not suite_names:
|
||||
raise ValueError("`task` must contain at least one LIBERO suite name.")
|
||||
|
||||
logger.info(
|
||||
"Creating LIBERO envs | suites=%s | n_envs(per task)=%d | init_states=%s | multitask_eval=%s",
|
||||
suite_names,
|
||||
n_envs,
|
||||
init_states,
|
||||
bool(multitask_eval),
|
||||
)
|
||||
if task_ids_filter is not None:
|
||||
logger.info("Restricting to task_ids=%s", task_ids_filter)
|
||||
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for suite_name in suite_names:
|
||||
suite = _get_suite(suite_name)
|
||||
total = len(suite.tasks)
|
||||
selected = _select_task_ids(total, task_ids_filter)
|
||||
|
||||
if not selected:
|
||||
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
|
||||
|
||||
for tid in selected:
|
||||
fns = _make_env_fns(
|
||||
suite=suite,
|
||||
suite_name=suite_name,
|
||||
task_id=tid,
|
||||
n_envs=n_envs,
|
||||
camera_names=camera_names,
|
||||
init_states=init_states,
|
||||
gym_kwargs=gym_kwargs,
|
||||
LiberoEnv=LiberoEnv,
|
||||
)
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
logger.debug("Built vec env | suite=%s | task_id=%d | n_envs=%d", suite_name, tid, n_envs)
|
||||
|
||||
# return plain dicts for predictability
|
||||
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||
|
||||
|
||||
def quat2axisangle(quat):
|
||||
"""
|
||||
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
||||
|
||||
Converts quaternion to axis-angle format.
|
||||
Returns a unit vector direction scaled by its angle in radians.
|
||||
|
||||
Args:
|
||||
quat (np.array): (x,y,z,w) vec4 float angles
|
||||
|
||||
Returns:
|
||||
np.array: (ax,ay,az) axis-angle exponential coordinates
|
||||
"""
|
||||
# clip quaternion
|
||||
if quat[3] > 1.0:
|
||||
quat[3] = 1.0
|
||||
elif quat[3] < -1.0:
|
||||
quat[3] = -1.0
|
||||
|
||||
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||
if math.isclose(den, 0.0):
|
||||
# This is (close to) a zero degree rotation, immediately return
|
||||
return np.zeros(3)
|
||||
|
||||
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||
|
||||
|
||||
def get_task_init_states(task_suite, i):
|
||||
init_states_path = os.path.join(
|
||||
get_libero_path("init_states"),
|
||||
task_suite.tasks[i].problem_folder,
|
||||
task_suite.tasks[i].init_states_file,
|
||||
)
|
||||
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
|
||||
return init_states
|
||||
|
||||
|
||||
def get_libero_dummy_action():
|
||||
"""Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
|
||||
return [0, 0, 0, 0, 0, 0, -1]
|
||||
|
||||
|
||||
OBS_STATE_DIM = 8
|
||||
ACTION_DIM = 7
|
||||
|
||||
|
||||
class LiberoEnv(gym.Env):
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_suite,
|
||||
task_id,
|
||||
task_suite_name,
|
||||
camera_name="agentview_image,robot0_eye_in_hand_image",
|
||||
obs_type="pixels",
|
||||
render_mode="rgb_array",
|
||||
observation_width=256,
|
||||
observation_height=256,
|
||||
visualization_width=640,
|
||||
visualization_height=480,
|
||||
init_states=True,
|
||||
episode_index=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.task_id = task_id
|
||||
self.obs_type = obs_type
|
||||
self.render_mode = render_mode
|
||||
self.observation_width = observation_width
|
||||
self.observation_height = observation_height
|
||||
self.visualization_width = visualization_width
|
||||
self.visualization_height = visualization_height
|
||||
self.init_states = init_states
|
||||
self.camera_name = camera_name.split(
|
||||
","
|
||||
) # agentview_image (main) or robot0_eye_in_hand_image (wrist)
|
||||
|
||||
# Map raw camera names to "image1" and "image2".
|
||||
# The preprocessing step `preprocess_observation` will then prefix these with `.images.*`,
|
||||
# following the LeRobot convention (e.g., `observation.images.image`, `observation.images.image2`).
|
||||
# This ensures the policy consistently receives observations in the
|
||||
# expected format regardless of the original camera naming.
|
||||
self.camera_name_mapping = {
|
||||
"agentview_image": "image",
|
||||
"robot0_eye_in_hand_image": "image2",
|
||||
}
|
||||
|
||||
self.num_steps_wait = (
|
||||
10 # Do nothing for the first few timesteps to wait for the simulator drops objects
|
||||
)
|
||||
self.episode_index = episode_index
|
||||
|
||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||
"libero_spatial": 220, # longest training demo has 193 steps
|
||||
"libero_object": 280, # longest training demo has 254 steps
|
||||
"libero_goal": 300, # longest training demo has 270 steps
|
||||
"libero_10": 520, # longest training demo has 505 steps
|
||||
"libero_90": 400, # longest training demo has 373 steps
|
||||
}
|
||||
default_steps = 500
|
||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||
|
||||
images = {}
|
||||
for cam in self.camera_name:
|
||||
images[self.camera_name_mapping[cam]] = spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
if self.obs_type == "state":
|
||||
raise NotImplementedError(
|
||||
"The 'state' observation type is not supported in LiberoEnv. "
|
||||
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
|
||||
)
|
||||
|
||||
elif self.obs_type == "pixels":
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(images),
|
||||
}
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(images),
|
||||
"agent_pos": spaces.Box(
|
||||
low=-1000.0,
|
||||
high=1000.0,
|
||||
shape=(OBS_STATE_DIM,),
|
||||
dtype=np.float64,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
|
||||
|
||||
def render(self):
|
||||
raw_obs = self._env.env._get_observations()
|
||||
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
|
||||
return image
|
||||
|
||||
def _make_envs_task(self, task_suite, task_id: int = 0):
|
||||
task = task_suite.get_task(task_id)
|
||||
self.task = task.name
|
||||
self.task_description = task.language
|
||||
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
|
||||
|
||||
env_args = {
|
||||
"bddl_file_name": task_bddl_file,
|
||||
"camera_heights": self.observation_height,
|
||||
"camera_widths": self.observation_width,
|
||||
}
|
||||
env = OffScreenRenderEnv(**env_args)
|
||||
env.reset()
|
||||
if self.init_states:
|
||||
init_states = get_task_init_states(
|
||||
task_suite, task_id
|
||||
) # for benchmarking purpose, we fix the a set of initial states FIXME(mshukor): should be in the reset()?
|
||||
init_state_id = self.episode_index # episode index
|
||||
env.set_init_state(init_states[init_state_id])
|
||||
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
images = {}
|
||||
for camera_name in self.camera_name:
|
||||
image = raw_obs[camera_name]
|
||||
image = image[::-1, ::-1] # rotate 180 degrees
|
||||
images[self.camera_name_mapping[camera_name]] = image
|
||||
state = np.concatenate(
|
||||
(
|
||||
raw_obs["robot0_eef_pos"],
|
||||
quat2axisangle(raw_obs["robot0_eef_quat"]),
|
||||
raw_obs["robot0_gripper_qpos"],
|
||||
)
|
||||
)
|
||||
agent_pos = state
|
||||
if self.obs_type == "state":
|
||||
raise NotImplementedError(
|
||||
"The 'state' observation type is not supported in LiberoEnv. "
|
||||
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
|
||||
)
|
||||
elif self.obs_type == "pixels":
|
||||
obs = {"pixels": images.copy()}
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
obs = {
|
||||
"pixels": images.copy(),
|
||||
"agent_pos": agent_pos,
|
||||
}
|
||||
return obs
|
||||
|
||||
def reset(self, seed=None, **kwargs):
|
||||
super().reset(seed=seed)
|
||||
|
||||
self._env.seed(seed)
|
||||
raw_obs = self._env.reset()
|
||||
# Do nothing for the first few timesteps to wait for the simulator drops objects
|
||||
for _ in range(self.num_steps_wait):
|
||||
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
|
||||
def step(self, action):
|
||||
if action.ndim != 1:
|
||||
raise ValueError(
|
||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||
f"but got shape {action.shape} with ndim={action.ndim}"
|
||||
)
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
|
||||
is_success = self._env.check_success()
|
||||
terminated = done or is_success
|
||||
info["is_success"] = done # is_success
|
||||
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
if done:
|
||||
self.reset()
|
||||
print(self.task, self.task_id, done, is_success)
|
||||
truncated = False
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def close(self):
|
||||
self._env.close()
|
||||
|
||||
|
||||
def create_libero_envs1(
|
||||
task: str,
|
||||
n_envs: int,
|
||||
gym_kwargs: dict[str, Any] = None,
|
||||
camera_name: str = "agentview_image,robot0_eye_in_hand_image",
|
||||
init_states: bool = True,
|
||||
env_cls: Callable = None,
|
||||
multitask_eval: bool = True,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Here n_envs is per task and equal to the number of rollouts.
|
||||
Returns:
|
||||
dict[str, dict[str, list[LiberoEnv]]]: keys are task_suite and values are list of LiberoEnv envs.
|
||||
"""
|
||||
print("num envs", n_envs)
|
||||
print("multitask_eval", multitask_eval)
|
||||
print("gym_kwargs", gym_kwargs)
|
||||
if gym_kwargs is None:
|
||||
gym_kwargs = {}
|
||||
|
||||
if not multitask_eval:
|
||||
benchmark_dict = benchmark.get_benchmark_dict()
|
||||
task_suite = benchmark_dict[task]() # can also choose libero_spatial, libero_object, libero_10 etc.
|
||||
tasks_id = list(range(len(task_suite.tasks)))
|
||||
episode_indices = [0 for i in range(len(tasks_id))]
|
||||
if len(tasks_id) == 1:
|
||||
tasks_id = [tasks_id[0] for _ in range(n_envs)]
|
||||
episode_indices = list(range(n_envs))
|
||||
elif len(tasks_id) < n_envs and n_envs % len(tasks_id) == 0:
|
||||
n_repeat = n_envs // len(tasks_id)
|
||||
print("n_repeat", n_repeat)
|
||||
episode_indices = []
|
||||
for _ in range(len(tasks_id)):
|
||||
episode_indices.extend(list(range(n_repeat)))
|
||||
tasks_id = list(chain.from_iterable([[item] * n_repeat for item in tasks_id]))
|
||||
elif n_envs < len(tasks_id):
|
||||
tasks_id = tasks_id[:n_envs]
|
||||
episode_indices = list(range(n_envs))[:n_envs]
|
||||
print(f"WARNING: n_envs < len(tasks_id), evaluating only on {tasks_id}")
|
||||
print(f"Creating Libero envs with task ids {tasks_id} from suite {task}")
|
||||
assert n_envs == len(tasks_id), (
|
||||
f"len(n_envs) and tasks_id should be the same, got {n_envs} and {len(tasks_id)}"
|
||||
)
|
||||
return env_cls(
|
||||
[
|
||||
lambda i=i: LiberoEnv(
|
||||
task_suite=task_suite,
|
||||
task_id=tasks_id[i],
|
||||
task_suite_name=task,
|
||||
camera_name=camera_name,
|
||||
init_states=init_states,
|
||||
episode_index=episode_indices[i],
|
||||
**gym_kwargs,
|
||||
)
|
||||
for i in range(n_envs)
|
||||
]
|
||||
)
|
||||
else:
|
||||
envs = defaultdict(dict)
|
||||
benchmark_dict = benchmark.get_benchmark_dict()
|
||||
task = task.split(",")
|
||||
for _task in task:
|
||||
task_suite = benchmark_dict[
|
||||
_task
|
||||
]() # can also choose libero_spatial, libero_object, libero_10 etc.
|
||||
tasks_ids = list(range(len(task_suite.tasks)))
|
||||
for tasks_id in tasks_ids:
|
||||
episode_indices = list(range(n_envs))
|
||||
print(
|
||||
f"Creating Libero envs with task ids {tasks_id} from suite {_task}, episode_indices: {episode_indices}"
|
||||
)
|
||||
envs_list = [
|
||||
(
|
||||
lambda i=i,
|
||||
task_suite=task_suite,
|
||||
tasks_id=tasks_id,
|
||||
_task=_task,
|
||||
episode_indices=episode_indices: LiberoEnv(
|
||||
task_suite=task_suite,
|
||||
task_id=tasks_id,
|
||||
task_suite_name=_task,
|
||||
camera_name=camera_name,
|
||||
init_states=init_states,
|
||||
episode_index=episode_indices[i],
|
||||
**gym_kwargs,
|
||||
)
|
||||
)
|
||||
for i in range(n_envs)
|
||||
]
|
||||
envs[_task][tasks_id] = env_cls(envs_list)
|
||||
return envs
|
||||
@@ -134,3 +134,49 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic
|
||||
num_envs = observation[list(observation.keys())[0]].shape[0]
|
||||
observation["task"] = ["" for _ in range(num_envs)]
|
||||
return observation
|
||||
|
||||
|
||||
def _close_single_env(env: Any) -> None:
|
||||
"""Try to close a single env object if it exposes .close()."""
|
||||
try:
|
||||
close_fn = getattr(env, "close", None)
|
||||
if callable(close_fn):
|
||||
close_fn()
|
||||
except Exception as exc:
|
||||
# Best-effort close: log but don't raise
|
||||
LOG.debug("Exception while closing env %s: %s", env, exc)
|
||||
|
||||
|
||||
def close_envs(env_or_collection: Any) -> None:
|
||||
"""
|
||||
Close a single env or any nested structure of envs.
|
||||
|
||||
Accepts:
|
||||
- a single env with .close()
|
||||
- a Mapping of things (e.g. dict)
|
||||
- a Sequence of things (list/tuple) but NOT str/bytes
|
||||
- nested combinations of the above
|
||||
|
||||
This is intentionally permissive and best-effort: it will swallow exceptions
|
||||
encountered while closing individual envs and continue.
|
||||
"""
|
||||
# Guard: single object with close()
|
||||
if hasattr(env_or_collection, "close") and not isinstance(env_or_collection, (Mapping, Sequence)):
|
||||
_close_single_env(env_or_collection)
|
||||
return
|
||||
|
||||
# Mapping (e.g., {suite: {task_id: vec_env}})
|
||||
if isinstance(env_or_collection, Mapping):
|
||||
for v in env_or_collection.values():
|
||||
close_envs(v)
|
||||
return
|
||||
|
||||
# Sequence (list/tuple) but skip str/bytes
|
||||
if isinstance(env_or_collection, Sequence) and not isinstance(env_or_collection, (str, bytes)):
|
||||
for v in env_or_collection:
|
||||
close_envs(v)
|
||||
return
|
||||
|
||||
# Fallback: try to close if possible
|
||||
if hasattr(env_or_collection, "close"):
|
||||
_close_single_env(env_or_collection)
|
||||
|
||||
@@ -20,7 +20,7 @@ Helper to find the camera devices available in your system.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.find_cameras
|
||||
lerobot-find-cameras
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ Helper to find the USB port associated with your MotorsBus.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -107,6 +107,8 @@ X_SERIES_ENCODINGS_TABLE = {
|
||||
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
|
||||
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
|
||||
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
|
||||
"Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1],
|
||||
"Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1],
|
||||
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
|
||||
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
|
||||
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
|
||||
|
||||
@@ -222,7 +222,7 @@ class MotorsBus(abc.ABC):
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
```bash
|
||||
python -m lerobot.find_port.py
|
||||
lerobot-find-port.py
|
||||
>>> Finding all available ports for the MotorsBus.
|
||||
>>> ["/dev/tty.usbmodem575E0032081", "/dev/tty.usbmodem575E0031751"]
|
||||
>>> Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
@@ -446,7 +446,7 @@ class MotorsBus(abc.ABC):
|
||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||
raise ConnectionError(
|
||||
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
|
||||
"\nTry running `python -m lerobot.find_port`\n"
|
||||
"\nTry running `lerobot-find-port`\n"
|
||||
) from e
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -15,19 +15,6 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0.processor_pi0 import Pi0NewLineProcessor
|
||||
from .rlearn.configuration_rlearn import RLearNConfig as RLearNConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"PI0Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"RLearNConfig",
|
||||
]
|
||||
|
||||
@@ -35,6 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
@@ -50,16 +51,27 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
@@ -125,19 +137,23 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
actions = self.model(batch)[0]
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
@@ -287,7 +303,7 @@ class ACT(nn.Module):
|
||||
└───────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config: ACTConfig, dataset_stats=None):
|
||||
def __init__(self, config: ACTConfig):
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
super().__init__()
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_act_processor(
|
||||
config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
)
|
||||
@@ -35,6 +35,7 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
@@ -56,6 +57,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -68,6 +70,14 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
@@ -96,6 +106,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -124,6 +137,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
@@ -139,9 +153,11 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_diffusion_processor(
|
||||
config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
)
|
||||
+16
-149
@@ -14,14 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
@@ -32,18 +27,18 @@ from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
@@ -69,6 +64,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "pi0_openpi":
|
||||
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
|
||||
|
||||
return PI0OpenPIPolicy
|
||||
elif name == "pi05_openpi":
|
||||
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
|
||||
|
||||
return PI05OpenPIPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -81,10 +84,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "rlearn":
|
||||
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
|
||||
|
||||
return RLearNPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -102,149 +101,24 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "pi0_openpi":
|
||||
return PI0OpenPIConfig(**kwargs)
|
||||
elif policy_type == "pi05_openpi":
|
||||
return PI05OpenPIConfig(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "rlearn":
|
||||
return RLearNConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for the processor config."""
|
||||
|
||||
preprocessor_config_filename: str | None
|
||||
postprocessor_config_filename: str | None
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
|
||||
|
||||
def make_processor(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
"""Make a processor instance for a given policy type.
|
||||
|
||||
This function creates the appropriate processor configuration based on the policy type.
|
||||
Each policy type has its own processor with specific preprocessing steps.
|
||||
|
||||
Args:
|
||||
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
|
||||
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
|
||||
the processor from this path instead of creating a new one.
|
||||
**kwargs: Additional keyword arguments passed to the processor creation.
|
||||
|
||||
Returns:
|
||||
Tuple of (input_processor, output_processor) for the policy.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the policy type doesn't have a processor implemented.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# Load a pretrained processor
|
||||
# TODO(azouitine): Handle this case.
|
||||
return (
|
||||
RobotProcessor.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
),
|
||||
RobotProcessor.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new processor based on policy type
|
||||
if policy_cfg.type == "tdmpc":
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
|
||||
|
||||
processors = make_tdmpc_processor(
|
||||
config=cast(TDMPCConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "diffusion":
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
|
||||
|
||||
processors = make_diffusion_processor(
|
||||
cast(DiffusionConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "act":
|
||||
from lerobot.policies.act.processor_act import make_act_processor
|
||||
|
||||
processors = make_act_processor(
|
||||
config=cast(ACTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "vqbet":
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
|
||||
|
||||
processors = make_vqbet_processor(
|
||||
config=cast(VQBeTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "pi0":
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
|
||||
|
||||
processors = make_pi0_processor(
|
||||
config=cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "pi0fast":
|
||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
|
||||
|
||||
processors = make_pi0fast_processor(
|
||||
cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "sac":
|
||||
from lerobot.policies.sac.processor_sac import make_sac_processor
|
||||
|
||||
processors = make_sac_processor(
|
||||
cast(SACConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "reward_classifier":
|
||||
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
processors = make_classifier_processor(
|
||||
cast(RewardClassifierConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "smolvla":
|
||||
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
|
||||
|
||||
processors = make_smolvla_processor(
|
||||
cast(SmolVLAConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "rlearn":
|
||||
from lerobot.policies.rlearn.processor_rlearn import make_rlearn_processor
|
||||
|
||||
processors = make_rlearn_processor(
|
||||
cast(RLearNConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
return processors
|
||||
|
||||
|
||||
def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
episode_data_index: dict | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
"""Make an instance of a policy class.
|
||||
|
||||
@@ -287,6 +161,7 @@ def make_policy(
|
||||
kwargs = {}
|
||||
if ds_meta is not None:
|
||||
features = dataset_to_policy_features(ds_meta.features)
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
else:
|
||||
if not cfg.pretrained_path:
|
||||
logging.warning(
|
||||
@@ -294,18 +169,12 @@ def make_policy(
|
||||
"rather than a dataset. Normalization modules inside the policy will have infinite values "
|
||||
"by default without stats from a dataset."
|
||||
)
|
||||
if env_cfg is None:
|
||||
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass episode_data_index for RLearN policy to calculate proper progress
|
||||
if cfg.type == "rlearn" and episode_data_index is not None:
|
||||
kwargs["episode_data_index"] = episode_data_index
|
||||
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
@@ -317,7 +186,5 @@ def make_policy(
|
||||
|
||||
policy.to(cfg.device)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
return policy
|
||||
|
||||
@@ -30,7 +30,7 @@ pip install -e ".[pi0]"
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
@@ -38,7 +38,7 @@ python -m lerobot.scripts.train \
|
||||
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
||||
pretrained with VLM default parameters before pi0 finetuning:
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
@@ -56,15 +56,18 @@ from collections import deque
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertConfig,
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
from lerobot.utils.utils import get_safe_dtype, init_logging
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
@@ -220,17 +223,28 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
@@ -239,6 +253,99 @@ class PI0Policy(PreTrainedPolicy):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
|
||||
"""
|
||||
Transform state dict keys to match expected model structure.
|
||||
|
||||
Transformations:
|
||||
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
|
||||
model.paligemma_with_expert.paligemma.lm_head
|
||||
- model.paligemma_with_expert.paligemma.language_model.model ->
|
||||
model.paligemma_with_expert.paligemma.model.language_model
|
||||
- model.paligemma_with_expert.paligemma.vision_tower ->
|
||||
model.paligemma_with_expert.paligemma.model.vision_tower
|
||||
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
|
||||
model.paligemma_with_expert.paligemma.model.multi_modal_projector
|
||||
|
||||
Also handles tied weights between lm_head.weight and
|
||||
embed_tokens.weight.
|
||||
"""
|
||||
import re
|
||||
|
||||
transformed_dict = {}
|
||||
|
||||
transformations = [
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
|
||||
".paligemma_with_expert.paligemma.lm_head",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
|
||||
".paligemma_with_expert.paligemma.model.language_model",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
|
||||
".paligemma_with_expert.paligemma.model.vision_tower",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
|
||||
".paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
),
|
||||
]
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
for pattern, replacement in transformations:
|
||||
new_key = pattern.sub(replacement, new_key)
|
||||
transformed_dict[new_key] = value
|
||||
|
||||
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
|
||||
lm_head_key = None
|
||||
embed_tokens_key = None
|
||||
|
||||
for key in transformed_dict:
|
||||
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
|
||||
lm_head_key = key
|
||||
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
|
||||
embed_tokens_key = key
|
||||
if lm_head_key and embed_tokens_key:
|
||||
break
|
||||
|
||||
if lm_head_key and not embed_tokens_key:
|
||||
embed_tokens_key = lm_head_key.replace(
|
||||
".lm_head.weight", ".model.language_model.embed_tokens.weight"
|
||||
)
|
||||
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
|
||||
elif embed_tokens_key and not lm_head_key:
|
||||
lm_head_key = embed_tokens_key.replace(
|
||||
".model.language_model.embed_tokens.weight", ".lm_head.weight"
|
||||
)
|
||||
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
|
||||
|
||||
return transformed_dict
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(
|
||||
cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
|
||||
) -> "PI0Policy":
|
||||
"""Override to apply key transformations before loading."""
|
||||
from safetensors.torch import load_file
|
||||
|
||||
init_logging()
|
||||
# Load the state dict from file safely
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
|
||||
# Apply key transformations
|
||||
transformed_state_dict = cls._transform_state_dict_keys(state_dict)
|
||||
|
||||
# Load the transformed state dict
|
||||
msg = model.load_state_dict(transformed_state_dict, strict=strict)
|
||||
|
||||
# Log message
|
||||
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
|
||||
return model
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@@ -270,13 +377,14 @@ class PI0Policy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
@@ -286,6 +394,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -300,10 +410,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("action_is_pad")
|
||||
|
||||
@@ -370,6 +482,26 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_STATE].device
|
||||
tasks = batch["task"]
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
@@ -435,7 +567,7 @@ class PI0FlowMatching(nn.Module):
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config: PI0Config):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import (
|
||||
EnvTransition,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.rename_processor import RenameProcessor
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
|
||||
class Pi0NewLineProcessor(ProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one.
|
||||
This is required for the PaliGemma tokenizer.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Check if complementary_data exists
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None or "task" not in complementary_data:
|
||||
return transition
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return transition
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Add tokenized task features to the features."""
|
||||
return features
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
|
||||
|
||||
def make_pi0_processor(
|
||||
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
|
||||
TokenizerProcessor(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
)
|
||||
@@ -0,0 +1,92 @@
|
||||
# π₀.₅ (pi05)
|
||||
|
||||
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
|
||||
It is designed as a **Vision-Language-Action model with open-world generalization**.
|
||||
|
||||
---
|
||||
|
||||
### ⚠️ WARNING ⚠️
|
||||
|
||||
This project requires **patching the Hugging Face `transformers` library**.
|
||||
|
||||
1. Make sure you have the exact version installed:
|
||||
|
||||
```bash
|
||||
pip show transformers
|
||||
```
|
||||
|
||||
It must be version **4.53.2**.
|
||||
|
||||
2. Apply the custom patches by copying the modified files into your environment:
|
||||
|
||||
```bash
|
||||
cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
|
||||
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
|
||||
```
|
||||
|
||||
These patches overwrite parts of `transformers` to:
|
||||
- Support the **AdaRMS optimizer**,
|
||||
- Correctly control the precision of activations,
|
||||
- Allow the KV cache to be used without updates.
|
||||
|
||||
**Important:**
|
||||
|
||||
- This permanently modifies your `transformers` installation.
|
||||
- The changes survive reinstalls unless you explicitly remove the patched files or recreate the environment.
|
||||
|
||||
To undo and restore a clean state:
|
||||
|
||||
```bash
|
||||
pip uninstall transformers
|
||||
pip install transformers==4.53.2
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | π₀ | π₀.₅ |
|
||||
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
|
||||
| State Embedding | Uses `state_proj` layer | No state embedding |
|
||||
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
|
||||
| AdaRMS | Not used | Used in action expert |
|
||||
| Tokenizer Length | 48 tokens | 200 tokens |
|
||||
| Discrete State Input | False | True |
|
||||
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
|
||||
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
|
||||
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
|
||||
year = {2025},
|
||||
eprint = {2504.16054},
|
||||
archivePrefix= {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2504.16054},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
+4
-20
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,23 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from .configuration_pi05openpi import PI05OpenPIConfig
|
||||
from .modeling_pi05openpi import PI05OpenPIPolicy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
class PhoneOS(Enum):
|
||||
ANDROID = "android"
|
||||
IOS = "ios"
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("phone")
|
||||
@dataclass
|
||||
class PhoneConfig(TeleoperatorConfig):
|
||||
phone_os: PhoneOS = PhoneOS.IOS
|
||||
camera_offset = np.array(
|
||||
[0.0, -0.02, 0.04]
|
||||
) # iPhone 14 Pro camera is 2cm off center and 4cm above center
|
||||
__all__ = ["PI05OpenPIConfig", "PI05OpenPIPolicy"]
|
||||
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05_openpi")
|
||||
@dataclass
|
||||
class PI05OpenPIConfig(PreTrainedConfig):
|
||||
# Model architecture
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
discrete_state_input: bool | None = (
|
||||
True # Whether to use discrete state input # see openpi `Pi0Config, __post_init__`
|
||||
)
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
# Input / output structure
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32 # State dimension (will be padded to 32)
|
||||
max_action_dim: int = 32 # Action dimension (will be padded to 32)
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10 # Number of denoising steps during inference
|
||||
time_sampling_beta_alpha: float = 1.5 # Beta distribution alpha parameter for time sampling
|
||||
time_sampling_beta_beta: float = 1.0 # Beta distribution beta parameter for time sampling
|
||||
min_period: float = 4e-3 # Min period for sinusoidal positional encoding
|
||||
max_period: float = 4.0 # Max period for sinusoidal positional encoding
|
||||
|
||||
# Image preprocessing
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY, # Images are normalized to [-1, 1] in preprocessing
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Optimizer settings: see openpi `AdamW` and
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
|
||||
|
||||
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
|
||||
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
# Image features are now handled dynamically through dataset configuration
|
||||
# No need to auto-add hardcoded image keys
|
||||
|
||||
# State and action features are also handled dynamically through dataset configuration
|
||||
# The actual dimensions come from the feature shapes, max dimensions are used for padding only
|
||||
pass
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
+173
@@ -0,0 +1,173 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class GemmaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma-7B.
|
||||
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256000):
|
||||
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`GemmaModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 3072):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 24576):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 28):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The legacy activation function. It is overwritten by the `hidden_activation`.
|
||||
hidden_activation (`str` or `function`, *optional*):
|
||||
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 2):
|
||||
Beginning of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
use_adarms (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use ADARMS.
|
||||
adarms_cond_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the ADARMS condition.
|
||||
```python
|
||||
>>> from transformers import GemmaModel, GemmaConfig
|
||||
>>> # Initializing a Gemma gemma-7b style configuration
|
||||
>>> configuration = GemmaConfig()
|
||||
>>> # Initializing a model from the gemma-7b style configuration
|
||||
>>> model = GemmaModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
hidden_size=3072,
|
||||
intermediate_size=24576,
|
||||
num_hidden_layers=28,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
hidden_activation=None,
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
bos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
use_adarms: bool = False,
|
||||
adarms_cond_dim: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_activation = hidden_activation
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.use_adarms = use_adarms
|
||||
self.adarms_cond_dim = adarms_cond_dim
|
||||
|
||||
# Set default for adarms_cond_dim if use_adarms is True
|
||||
if self.use_adarms and self.adarms_cond_dim is None:
|
||||
self.adarms_cond_dim = self.hidden_size
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["GemmaConfig"]
|
||||
@@ -0,0 +1,895 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_gemma import GemmaConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.dim = dim
|
||||
self.cond_dim = cond_dim
|
||||
|
||||
# Dense layer for adaptive normalization (if cond_dim is provided)
|
||||
if cond_dim is not None:
|
||||
# self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16)
|
||||
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
|
||||
# Initialize with zeros (matches source implementation)
|
||||
nn.init.zeros_(self.dense.weight)
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
|
||||
self.dense = None
|
||||
|
||||
def _norm(self, x):
|
||||
# Compute variance in float32 (like the source implementation)
|
||||
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
|
||||
# Compute normalization in float32
|
||||
normed_inputs = x * torch.rsqrt(var + self.eps)
|
||||
return normed_inputs
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
normed_inputs = self._norm(x)
|
||||
|
||||
if cond is None or self.dense is None:
|
||||
# regular RMSNorm
|
||||
# scale by learned parameter in float32 (matches source implementation)
|
||||
normed_inputs = normed_inputs * (1.0 + self.weight.float())
|
||||
return normed_inputs.to(dtype), None # return in original dtype with None gate
|
||||
|
||||
# adaptive RMSNorm (if cond is provided and dense layer exists)
|
||||
if cond.shape[-1] != self.cond_dim:
|
||||
raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
|
||||
|
||||
# self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32)
|
||||
modulation = self.dense(cond)
|
||||
# Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
|
||||
if len(x.shape) == 3: # [batch, seq, features]
|
||||
modulation = modulation.unsqueeze(1)
|
||||
|
||||
scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
|
||||
|
||||
# Apply adaptive normalization: use model weight dtype to ensure compatibility
|
||||
# model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16)
|
||||
# scale = scale.to(model_dtype)
|
||||
# shift = shift.to(model_dtype)
|
||||
# gate = gate.to(model_dtype)
|
||||
# normed_inputs = normed_inputs.to(model_dtype) # Convert normed_inputs to model dtype
|
||||
|
||||
normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
|
||||
|
||||
return normed_inputs.to(dtype), gate.to(dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||
if self.dense is not None:
|
||||
repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
|
||||
return repr_str
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class GemmaRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: GemmaConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = (
|
||||
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def _gated_residual(x, y, gate):
|
||||
"""
|
||||
Applies gated residual connection with optional gate parameter.
|
||||
|
||||
Args:
|
||||
x: Input tensor (residual)
|
||||
y: Output tensor to be added
|
||||
gate: Optional gate tensor to modulate the addition
|
||||
|
||||
Returns:
|
||||
x + y if gate is None, otherwise x + y * gate
|
||||
"""
|
||||
if x is None and y is None:
|
||||
return None
|
||||
if x is None or y is None:
|
||||
return x if x is not None else y
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class GemmaAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: torch.Tensor | None,
|
||||
past_key_value: Cache | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# Use cache if provided
|
||||
if past_key_value is not None:
|
||||
if use_cache:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
else:
|
||||
key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class GemmaDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = GemmaMLP(config)
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: Cache | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
use_cache: bool | None = False,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: None
|
||||
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class GemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = GemmaConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["GemmaDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, GemmaRMSNorm):
|
||||
if hasattr(module, "weight"):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||
self.rotary_emb = GemmaRotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
"""
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# normalized
|
||||
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
_normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
adarms_cond=adarms_cond,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, adarms_cond)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
logits_to_keep: int | torch.Tensor = 0,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||
|
||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
adarms_cond=adarms_cond,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
The Gemma Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
[`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
||||
(e.g. GPT-2) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
||||
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
||||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||
each row of the batch).
|
||||
"""
|
||||
)
|
||||
class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = GemmaModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
) -> SequenceClassifierOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
|
||||
transformer_outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
adarms_cond=adarms_cond,
|
||||
)
|
||||
hidden_states = transformer_outputs.last_hidden_state
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
else:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config
|
||||
)
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = GemmaModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
) -> TokenClassifierOutput:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
adarms_cond=adarms_cond,
|
||||
)
|
||||
sequence_output = outputs.last_hidden_state
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GemmaModel",
|
||||
"GemmaForCausalLM",
|
||||
"GemmaForSequenceClassification",
|
||||
"GemmaForTokenClassification",
|
||||
"GemmaPreTrainedModel",
|
||||
]
|
||||
+666
@@ -0,0 +1,666 @@
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch PaliGemmamodel."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from ..auto import AutoModel
|
||||
from .configuration_paligemma import PaliGemmaConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Paligemma outputs, with hidden states and attentions.
|
||||
"""
|
||||
)
|
||||
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
image_hidden_states: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for PaliGemma causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
)
|
||||
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
||||
"""
|
||||
|
||||
loss: torch.FloatTensor | None = None
|
||||
logits: torch.FloatTensor | None = None
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None
|
||||
hidden_states: tuple[torch.FloatTensor] | None = None
|
||||
attentions: tuple[torch.FloatTensor] | None = None
|
||||
image_hidden_states: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
class PaliGemmaMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(
|
||||
config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear(image_features)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = PaliGemmaConfig
|
||||
base_model_prefix = ""
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["PaliGemmaMultiModalProjector"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of PaliGemmaisn't meant for training from scratch - only
|
||||
# inference and fine-tuning
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
||||
accepts_loss_kwargs = False
|
||||
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
|
||||
language_model = AutoModel.from_config(config=config.text_config)
|
||||
self.language_model = language_model
|
||||
|
||||
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
self.post_init()
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
token_type_ids=None,
|
||||
past_key_values=None,
|
||||
cache_position=None,
|
||||
input_tensor=None,
|
||||
is_training: bool | None = None,
|
||||
):
|
||||
if self.config.text_config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
is_training = is_training if is_training is not None else self.training
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(self.dtype).min
|
||||
if input_tensor is None:
|
||||
input_tensor = attention_mask
|
||||
|
||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
elif isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length),
|
||||
fill_value=min_dtype,
|
||||
dtype=self.dtype,
|
||||
device=cache_position.device,
|
||||
)
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
if is_training:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
else:
|
||||
causal_mask[:, :sequence_length] = 0.0
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
|
||||
# First unmask prefix tokens during training
|
||||
if is_training:
|
||||
if token_type_ids is None:
|
||||
raise ValueError("Token type ids must be provided during training")
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
||||
)
|
||||
|
||||
# Then apply padding mask (will mask pad tokens)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple | PaligemmaModelOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
|
||||
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
|
||||
>>> prompt = "Where is the cat standing?"
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs,)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Where is the cat standing?\nsnow"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id with PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
|
||||
)
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return PaligemmaModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
"^language_model.model": "model.language_model",
|
||||
"^vision_tower": "model.vision_tower",
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__(config)
|
||||
self.model = PaliGemmaModel(config)
|
||||
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.get_decoder()
|
||||
|
||||
def get_image_features(self, pixel_values):
|
||||
return self.model.get_image_features(pixel_values)
|
||||
|
||||
# Make modules available through conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
@property
|
||||
def vision_tower(self):
|
||||
return self.model.vision_tower
|
||||
|
||||
@property
|
||||
def multi_modal_projector(self):
|
||||
return self.model.multi_modal_projector
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
logits_to_keep: int | torch.Tensor = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> tuple | PaliGemmaCausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
|
||||
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
|
||||
>>> prompt = "Where is the cat standing?"
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs,)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Where is the cat standing?\nsnow"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
labels=labels,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return PaliGemmaCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
logits_to_keep=None,
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
token_type_ids=token_type_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# position_ids in Paligemma are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
||||
causal_mask = self.model._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length),
|
||||
fill_value=min_dtype,
|
||||
dtype=dtype,
|
||||
device=cache_position.device,
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]
|
||||
@@ -0,0 +1,5 @@
|
||||
import transformers
|
||||
|
||||
|
||||
def check_whether_transformers_replace_is_installed_correctly():
|
||||
return transformers.__version__ == "4.53.2"
|
||||
+1283
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,92 @@
|
||||
# π₀ (pi0)
|
||||
|
||||
This repository contains the Hugging Face port of **π₀**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
|
||||
It is designed as a **Vision-Language-Action flow model for general robot control**.
|
||||
|
||||
---
|
||||
|
||||
### ⚠️ WARNING ⚠️
|
||||
|
||||
This project requires **patching the Hugging Face `transformers` library**.
|
||||
|
||||
1. Make sure you have the exact version installed:
|
||||
|
||||
```bash
|
||||
pip show transformers
|
||||
```
|
||||
|
||||
It must be version **4.53.2**.
|
||||
|
||||
2. Apply the custom patches by copying the modified files into your environment:
|
||||
|
||||
```bash
|
||||
cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
|
||||
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
|
||||
```
|
||||
|
||||
These patches overwrite parts of `transformers` to:
|
||||
- Support the **AdaRMS optimizer**,
|
||||
- Correctly control the precision of activations,
|
||||
- Allow the KV cache to be used without updates.
|
||||
|
||||
**Important:**
|
||||
|
||||
- This permanently modifies your `transformers` installation.
|
||||
- The changes survive reinstalls unless you explicitly remove the patched files or recreate the environment.
|
||||
|
||||
To undo and restore a clean state:
|
||||
|
||||
```bash
|
||||
pip uninstall transformers
|
||||
pip install transformers==4.53.2
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | π₀ | π₀.₅ |
|
||||
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
|
||||
| State Embedding | Uses `state_proj` layer | No state embedding |
|
||||
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
|
||||
| AdaRMS | Not used | Used in action expert |
|
||||
| Tokenizer Length | 48 tokens | 200 tokens |
|
||||
| Discrete State Input | False | True |
|
||||
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite both **OpenPI** and the π₀ paper:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{black2024pi0visionlanguageactionflowmodel,
|
||||
title = {π₀: A Vision-Language-Action Flow Model for General Robot Control},
|
||||
author = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky},
|
||||
year = {2024},
|
||||
eprint = {2410.24164},
|
||||
archivePrefix= {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2410.24164},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi0openpi import PI0OpenPIConfig
|
||||
from .modeling_pi0openpi import PI0OpenPIPolicy
|
||||
|
||||
__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy"]
|
||||
@@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0_openpi")
|
||||
@dataclass
|
||||
class PI0OpenPIConfig(PreTrainedConfig):
|
||||
# Model architecture
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
# Input / output structure
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32 # State dimension (will be padded to 32)
|
||||
max_action_dim: int = 32 # Action dimension (will be padded to 32)
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10 # Number of denoising steps during inference
|
||||
time_sampling_beta_alpha: float = 1.5 # Beta distribution alpha parameter for time sampling
|
||||
time_sampling_beta_beta: float = 1.0 # Beta distribution beta parameter for time sampling
|
||||
min_period: float = 4e-3 # Min period for sinusoidal positional encoding
|
||||
max_period: float = 4.0 # Max period for sinusoidal positional encoding
|
||||
|
||||
# Image preprocessing
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY, # Images are normalized to [-1, 1] in preprocessing
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Optimizer settings: see openpi `AdamW` and
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
tokenizer_max_length: int = 48 # pi0=48, see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
|
||||
|
||||
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
|
||||
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
# Image features are now handled dynamically through dataset configuration
|
||||
# No need to auto-add hardcoded image keys
|
||||
|
||||
# State and action features are also handled dynamically through dataset configuration
|
||||
# The actual dimensions come from the feature shapes, max dimensions are used for padding only
|
||||
pass
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
+173
@@ -0,0 +1,173 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class GemmaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma-7B.
|
||||
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256000):
|
||||
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`GemmaModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 3072):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 24576):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 28):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The legacy activation function. It is overwritten by the `hidden_activation`.
|
||||
hidden_activation (`str` or `function`, *optional*):
|
||||
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 2):
|
||||
Beginning of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
use_adarms (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use ADARMS.
|
||||
adarms_cond_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the ADARMS condition.
|
||||
```python
|
||||
>>> from transformers import GemmaModel, GemmaConfig
|
||||
>>> # Initializing a Gemma gemma-7b style configuration
|
||||
>>> configuration = GemmaConfig()
|
||||
>>> # Initializing a model from the gemma-7b style configuration
|
||||
>>> model = GemmaModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
hidden_size=3072,
|
||||
intermediate_size=24576,
|
||||
num_hidden_layers=28,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
hidden_activation=None,
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
bos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
use_adarms: bool = False,
|
||||
adarms_cond_dim: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_activation = hidden_activation
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.use_adarms = use_adarms
|
||||
self.adarms_cond_dim = adarms_cond_dim
|
||||
|
||||
# Set default for adarms_cond_dim if use_adarms is True
|
||||
if self.use_adarms and self.adarms_cond_dim is None:
|
||||
self.adarms_cond_dim = self.hidden_size
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["GemmaConfig"]
|
||||
@@ -0,0 +1,895 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_gemma import GemmaConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.dim = dim
|
||||
self.cond_dim = cond_dim
|
||||
|
||||
# Dense layer for adaptive normalization (if cond_dim is provided)
|
||||
if cond_dim is not None:
|
||||
# self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16)
|
||||
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
|
||||
# Initialize with zeros (matches source implementation)
|
||||
nn.init.zeros_(self.dense.weight)
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
|
||||
self.dense = None
|
||||
|
||||
def _norm(self, x):
|
||||
# Compute variance in float32 (like the source implementation)
|
||||
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
|
||||
# Compute normalization in float32
|
||||
normed_inputs = x * torch.rsqrt(var + self.eps)
|
||||
return normed_inputs
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
normed_inputs = self._norm(x)
|
||||
|
||||
if cond is None or self.dense is None:
|
||||
# regular RMSNorm
|
||||
# scale by learned parameter in float32 (matches source implementation)
|
||||
normed_inputs = normed_inputs * (1.0 + self.weight.float())
|
||||
return normed_inputs.to(dtype), None # return in original dtype with None gate
|
||||
|
||||
# adaptive RMSNorm (if cond is provided and dense layer exists)
|
||||
if cond.shape[-1] != self.cond_dim:
|
||||
raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
|
||||
|
||||
# self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32)
|
||||
modulation = self.dense(cond)
|
||||
# Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
|
||||
if len(x.shape) == 3: # [batch, seq, features]
|
||||
modulation = modulation.unsqueeze(1)
|
||||
|
||||
scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
|
||||
|
||||
# Apply adaptive normalization: use model weight dtype to ensure compatibility
|
||||
# model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16)
|
||||
# scale = scale.to(model_dtype)
|
||||
# shift = shift.to(model_dtype)
|
||||
# gate = gate.to(model_dtype)
|
||||
# normed_inputs = normed_inputs.to(model_dtype) # Convert normed_inputs to model dtype
|
||||
|
||||
normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
|
||||
|
||||
return normed_inputs.to(dtype), gate.to(dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||
if self.dense is not None:
|
||||
repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
|
||||
return repr_str
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class GemmaRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: GemmaConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = (
|
||||
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def _gated_residual(x, y, gate):
|
||||
"""
|
||||
Applies gated residual connection with optional gate parameter.
|
||||
|
||||
Args:
|
||||
x: Input tensor (residual)
|
||||
y: Output tensor to be added
|
||||
gate: Optional gate tensor to modulate the addition
|
||||
|
||||
Returns:
|
||||
x + y if gate is None, otherwise x + y * gate
|
||||
"""
|
||||
if x is None and y is None:
|
||||
return None
|
||||
if x is None or y is None:
|
||||
return x if x is not None else y
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class GemmaAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: torch.Tensor | None,
|
||||
past_key_value: Cache | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# Use cache if provided
|
||||
if past_key_value is not None:
|
||||
if use_cache:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
else:
|
||||
key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class GemmaDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = GemmaMLP(config)
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: Cache | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
use_cache: bool | None = False,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: None
|
||||
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class GemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = GemmaConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["GemmaDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, GemmaRMSNorm):
|
||||
if hasattr(module, "weight"):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||
self.rotary_emb = GemmaRotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
"""
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# normalized
|
||||
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
_normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
adarms_cond=adarms_cond,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, adarms_cond)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
logits_to_keep: int | torch.Tensor = 0,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||
|
||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
adarms_cond=adarms_cond,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
The Gemma Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
[`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
||||
(e.g. GPT-2) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
||||
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
||||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||
each row of the batch).
|
||||
"""
|
||||
)
|
||||
class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = GemmaModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
) -> SequenceClassifierOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
|
||||
transformer_outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
adarms_cond=adarms_cond,
|
||||
)
|
||||
hidden_states = transformer_outputs.last_hidden_state
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
else:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config
|
||||
)
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = GemmaModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
) -> TokenClassifierOutput:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
adarms_cond=adarms_cond,
|
||||
)
|
||||
sequence_output = outputs.last_hidden_state
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GemmaModel",
|
||||
"GemmaForCausalLM",
|
||||
"GemmaForSequenceClassification",
|
||||
"GemmaForTokenClassification",
|
||||
"GemmaPreTrainedModel",
|
||||
]
|
||||
+666
@@ -0,0 +1,666 @@
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch PaliGemmamodel."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from ..auto import AutoModel
|
||||
from .configuration_paligemma import PaliGemmaConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Paligemma outputs, with hidden states and attentions.
|
||||
"""
|
||||
)
|
||||
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
image_hidden_states: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for PaliGemma causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
)
|
||||
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
||||
"""
|
||||
|
||||
loss: torch.FloatTensor | None = None
|
||||
logits: torch.FloatTensor | None = None
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None
|
||||
hidden_states: tuple[torch.FloatTensor] | None = None
|
||||
attentions: tuple[torch.FloatTensor] | None = None
|
||||
image_hidden_states: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
class PaliGemmaMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(
|
||||
config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear(image_features)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = PaliGemmaConfig
|
||||
base_model_prefix = ""
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["PaliGemmaMultiModalProjector"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of PaliGemmaisn't meant for training from scratch - only
|
||||
# inference and fine-tuning
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
||||
accepts_loss_kwargs = False
|
||||
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
|
||||
language_model = AutoModel.from_config(config=config.text_config)
|
||||
self.language_model = language_model
|
||||
|
||||
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
self.post_init()
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
token_type_ids=None,
|
||||
past_key_values=None,
|
||||
cache_position=None,
|
||||
input_tensor=None,
|
||||
is_training: bool | None = None,
|
||||
):
|
||||
if self.config.text_config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
is_training = is_training if is_training is not None else self.training
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(self.dtype).min
|
||||
if input_tensor is None:
|
||||
input_tensor = attention_mask
|
||||
|
||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
elif isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length),
|
||||
fill_value=min_dtype,
|
||||
dtype=self.dtype,
|
||||
device=cache_position.device,
|
||||
)
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
if is_training:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
else:
|
||||
causal_mask[:, :sequence_length] = 0.0
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
|
||||
# First unmask prefix tokens during training
|
||||
if is_training:
|
||||
if token_type_ids is None:
|
||||
raise ValueError("Token type ids must be provided during training")
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
||||
)
|
||||
|
||||
# Then apply padding mask (will mask pad tokens)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple | PaligemmaModelOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
|
||||
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
|
||||
>>> prompt = "Where is the cat standing?"
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs,)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Where is the cat standing?\nsnow"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id with PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
|
||||
)
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return PaligemmaModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
"^language_model.model": "model.language_model",
|
||||
"^vision_tower": "model.vision_tower",
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__(config)
|
||||
self.model = PaliGemmaModel(config)
|
||||
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.get_decoder()
|
||||
|
||||
def get_image_features(self, pixel_values):
|
||||
return self.model.get_image_features(pixel_values)
|
||||
|
||||
# Make modules available through conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
@property
|
||||
def vision_tower(self):
|
||||
return self.model.vision_tower
|
||||
|
||||
@property
|
||||
def multi_modal_projector(self):
|
||||
return self.model.multi_modal_projector
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
logits_to_keep: int | torch.Tensor = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> tuple | PaliGemmaCausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
|
||||
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
|
||||
>>> prompt = "Where is the cat standing?"
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs,)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Where is the cat standing?\nsnow"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
labels=labels,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return PaliGemmaCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
logits_to_keep=None,
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
token_type_ids=token_type_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# position_ids in Paligemma are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
||||
causal_mask = self.model._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length),
|
||||
fill_value=min_dtype,
|
||||
dtype=dtype,
|
||||
device=cache_position.device,
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]
|
||||
@@ -0,0 +1,5 @@
|
||||
import transformers
|
||||
|
||||
|
||||
def check_whether_transformers_replace_is_installed_correctly():
|
||||
return transformers.__version__ == "4.53.2"
|
||||
+1283
File diff suppressed because it is too large
Load Diff
@@ -25,14 +25,14 @@ Disclaimer: It is not expected to perform as well as the original implementation
|
||||
|
||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/pi0fast_base \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of training the pi0+FAST neural network with from scratch:
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=pi0fast \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
@@ -58,6 +58,7 @@ from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
@@ -145,6 +146,14 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FAST(config)
|
||||
|
||||
@@ -212,6 +221,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
@@ -224,6 +235,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -236,6 +249,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss_dict = self.model.forward(batch)
|
||||
return loss_dict["loss"], loss_dict
|
||||
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_pi0fast_processor(
|
||||
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
)
|
||||
@@ -1,128 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("rlearn")
|
||||
@dataclass
|
||||
class RLearNConfig(PreTrainedConfig):
|
||||
"""Configuration for a video-language conditioned reward model (RLearN).
|
||||
|
||||
Inputs:
|
||||
- Visual frames (one or multiple cameras). Optionally a short sequence.
|
||||
- A language instruction/goal string.
|
||||
|
||||
Output:
|
||||
- Per-timestep reward logits or a single-step reward logit.
|
||||
|
||||
Notes:
|
||||
- This follows the ReWiND paper architecture. It uses frozen vision/text encoders
|
||||
(DINOv3 for vision, SigLIP2 for language) and trains a
|
||||
lightweight temporal aggregator + head.
|
||||
"""
|
||||
|
||||
# Encoders - Use SigLIP2 for both vision and text (shared checkpoint)
|
||||
vision_model_name: str = "google/siglip2-base-patch16-224"
|
||||
text_model_name: str = "google/siglip2-base-patch16-224"
|
||||
freeze_backbones: bool = True
|
||||
|
||||
# Sequence length, amount of past frames including current one to use in the temporal model
|
||||
max_seq_len: int = 16
|
||||
# Temporal sampling stride
|
||||
temporal_sampling_stride: int = 3 # Open x mostly has fps 10, and rewind has seq len 16, ours is 30fps so 30/10 = 3 stride lenght to have same timeframe!
|
||||
|
||||
# Model dimensions and transformer
|
||||
dim_model: int = 512
|
||||
num_layers: int = 4
|
||||
num_heads: int = 8
|
||||
ff_mult: int = 4 # Feed-forward multiplier, hidden = dim_model * ff_mult
|
||||
dropout: float = 0.05
|
||||
|
||||
# --- reward head options ---
|
||||
use_categorical_rewards: bool = False # classification over bins
|
||||
num_reward_bins: int = 25
|
||||
reward_min_value: float = 0.0 # for HL-Gauss range
|
||||
reward_max_value: float = 1.0
|
||||
use_hl_gauss_loss: bool = True # if False -> plain regression
|
||||
hl_gauss_num_bins: int = 25 # histogram resolution
|
||||
|
||||
# Inference-time subsampling and regularization
|
||||
inference_stride: int = 1 # inference_stride is an extra, second downsampling applied in forward after window sampling/rewind. Keep it at 1 to disable extra skipping
|
||||
frame_dropout_p: float = 0.10
|
||||
|
||||
# Training
|
||||
learning_rate: float = 5e-4
|
||||
weight_decay: float = 0.01
|
||||
head_lr_multiplier: float = 5.0
|
||||
logit_eps: float = 1e-4
|
||||
regularizer_warmup_steps: int = 500
|
||||
|
||||
# Performance optimizations
|
||||
use_amp: bool = False
|
||||
compile_model: bool = True
|
||||
|
||||
# ReWiND augmentation
|
||||
rewind_prob: float = 0.3 #0.8
|
||||
rewind_last3_prob: float = 0.0 #0.3
|
||||
mismatch_prob: float = 0.0 #0.2
|
||||
|
||||
# Normalization presets
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Required path to episodes.jsonl for episode boundaries
|
||||
episodes_jsonl_path: str | None = "meta/episodes.jsonl"
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# Require at least one image feature. Language is recommended but optional (can be blank).
|
||||
if not self.image_features:
|
||||
raise ValueError(
|
||||
"You must provide at least one image feature for RLearN (e.g. 'observation.image')."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
# Request a long enough context so in-window stride sampling can be >1.
|
||||
# We ask for (max_seq_len * temporal_sampling_stride) frames ending at t=0.
|
||||
# Example: max_seq_len=16, temporal_sampling_stride=3 → 48 deltas → ~46 frames available.
|
||||
total_needed = self.max_seq_len * max(1, int(self.temporal_sampling_stride))
|
||||
return list(range(1 - total_needed, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None:
|
||||
# Not an action chunking policy.
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
# ReWiND generates progress labels on-the-fly, doesn't need reward data
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self): # type: ignore[override]
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
|
||||
return AdamWConfig(lr=self.learning_rate, weight_decay=self.weight_decay)
|
||||
|
||||
def get_scheduler_preset(self): # type: ignore[override]
|
||||
# No scheduler by default.
|
||||
return None
|
||||
@@ -1,392 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Standalone evaluation script for RLearN models.
|
||||
|
||||
This script evaluates RLearN reward models on episodes from a dataset,
|
||||
generating comparison plots between ground truth rewards and model predictions.
|
||||
|
||||
Usage:
|
||||
python src/lerobot/policies/rlearn/eval_script.py --model MODEL_NAME --dataset DATASET_REPO --episodes N
|
||||
|
||||
Example:
|
||||
python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_18 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path for imports
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
|
||||
|
||||
import warnings
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.stats import spearmanr
|
||||
from tqdm import tqdm
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# LeRobot imports
|
||||
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
|
||||
|
||||
|
||||
def _to_chw_float01(img):
|
||||
"""Ensure CHW float in [0,1]."""
|
||||
if isinstance(img, np.ndarray):
|
||||
img = torch.from_numpy(img)
|
||||
# HWC -> CHW if needed
|
||||
if len(img.shape) == 3 and img.shape[-1] in (1, 3, 4):
|
||||
img = img.permute(2, 0, 1)
|
||||
if img.dtype == torch.uint8:
|
||||
img = img.float() / 255.0
|
||||
else:
|
||||
img = img.float()
|
||||
return torch.clamp(img, 0.0, 1.0)
|
||||
|
||||
|
||||
def _get_language(frame_data):
|
||||
lang = None
|
||||
if OBS_LANGUAGE in frame_data:
|
||||
lang = frame_data[OBS_LANGUAGE]
|
||||
if isinstance(lang, list) and len(lang) > 0:
|
||||
lang = lang[0]
|
||||
elif "task" in frame_data:
|
||||
lang = frame_data["task"]
|
||||
return lang if isinstance(lang, str) else "No language provided"
|
||||
|
||||
|
||||
def _get_ground_truth_reward(frame_data):
|
||||
"""Try common keys for ground-truth reward. Return None if unavailable."""
|
||||
for key in ("reward", "rewards", "gt_reward", "progress"):
|
||||
if key in frame_data:
|
||||
r = frame_data[key]
|
||||
# unwrap single-element lists/arrays
|
||||
if isinstance(r, (list, np.ndarray)) and np.array(r).size == 1:
|
||||
r = float(np.array(r).reshape(-1)[0])
|
||||
try:
|
||||
return float(r)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def extract_episode_frames_and_gt(dataset, episode_idx):
|
||||
"""Load a full episode: frames (T, C, H, W), language (str), gt_rewards (np.ndarray or None)."""
|
||||
ep_start = dataset.episode_data_index["from"][episode_idx].item()
|
||||
ep_end = dataset.episode_data_index["to"][episode_idx].item()
|
||||
T = ep_end - ep_start
|
||||
|
||||
frames = []
|
||||
gt_rewards = []
|
||||
language = None
|
||||
|
||||
for t in range(T):
|
||||
item = dataset[ep_start + t]
|
||||
|
||||
# image(s)
|
||||
if OBS_IMAGES in item:
|
||||
img = item[OBS_IMAGES]
|
||||
elif OBS_IMAGE in item:
|
||||
img = item[OBS_IMAGE]
|
||||
else:
|
||||
# try to find an image-like key
|
||||
img_keys = [k for k in item.keys() if "image" in k.lower()]
|
||||
if not img_keys:
|
||||
continue
|
||||
img = item[img_keys[0]]
|
||||
|
||||
frames.append(_to_chw_float01(img))
|
||||
|
||||
# language once
|
||||
if language is None:
|
||||
language = _get_language(item)
|
||||
|
||||
# ground-truth reward (optional)
|
||||
r = _get_ground_truth_reward(item)
|
||||
gt_rewards.append(r)
|
||||
|
||||
if not frames:
|
||||
return None, None, None
|
||||
|
||||
frames = torch.stack(frames) # (T, C, H, W)
|
||||
|
||||
# If all GT entries are None, treat as missing
|
||||
if all(r is None for r in gt_rewards):
|
||||
gt_rewards = None
|
||||
else:
|
||||
# Replace None by forward filling
|
||||
arr = np.array([np.nan if r is None else float(r) for r in gt_rewards], dtype=float)
|
||||
# forward/back fill
|
||||
if np.isnan(arr[0]):
|
||||
first_valid = np.flatnonzero(~np.isnan(arr))
|
||||
if len(first_valid) > 0:
|
||||
arr[0] = arr[first_valid[0]]
|
||||
else:
|
||||
arr[0] = 0.0
|
||||
for i in range(1, len(arr)):
|
||||
if np.isnan(arr[i]):
|
||||
arr[i] = arr[i - 1]
|
||||
gt_rewards = arr
|
||||
|
||||
return frames, language or "No language provided", gt_rewards
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda", temporal_stride: int | None = None):
|
||||
"""
|
||||
Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
|
||||
left-pad by repeating the first frame to length L (<= 16), and take the prediction
|
||||
corresponding to the current frame's position in the window.
|
||||
Returns np.ndarray of shape (T,).
|
||||
"""
|
||||
T = frames.shape[0]
|
||||
cfg = getattr(model, "config", object())
|
||||
L = int(getattr(cfg, "max_seq_len", max_seq_len))
|
||||
L = min(L, max_seq_len) # hard-cap at 16
|
||||
# Use the same temporal stride as training (skip s-1 frames, take 1)
|
||||
if temporal_stride is None:
|
||||
temporal_stride = int(getattr(cfg, "temporal_sampling_stride", 1))
|
||||
temporal_stride = max(1, int(temporal_stride))
|
||||
|
||||
# Preprocessed tensor on device
|
||||
frames = frames.to(device)
|
||||
|
||||
windows = []
|
||||
frame_positions = [] # Track which temporal position each frame should use
|
||||
left_pad_counts = [] # Number of left-pad (OOB) frames per window
|
||||
|
||||
for i in range(T):
|
||||
# Build indices with stride s: [..., i-3, i] etc., left-padded by clamping to 0
|
||||
idxs = [i - (L - 1 - j) * temporal_stride for j in range(L)]
|
||||
pad_needed = sum(1 for k in idxs if k < 0)
|
||||
clamped = [0 if k < 0 else (T - 1 if k >= T else k) for k in idxs]
|
||||
window = frames[clamped] # (L, C, H, W)
|
||||
|
||||
# Use the last temporal position (current frame) for reading model output
|
||||
frame_pos = L - 1
|
||||
|
||||
windows.append(window)
|
||||
frame_positions.append(frame_pos)
|
||||
left_pad_counts.append(pad_needed)
|
||||
|
||||
preds = np.zeros(T, dtype=float)
|
||||
|
||||
for s in range(0, T, batch_size):
|
||||
e = min(s + batch_size, T)
|
||||
batch_windows = torch.stack(windows[s:e]) # (B, L, C, H, W)
|
||||
batch_positions = frame_positions[s:e]
|
||||
|
||||
batch = {OBS_IMAGES: batch_windows, OBS_LANGUAGE: [language] * (e - s)} # expects (B, L, C, H, W)
|
||||
|
||||
# Model returns (B, L) predictions for each temporal position
|
||||
values = model.predict_rewards(batch) # torch.Tensor (B, L)
|
||||
|
||||
# Apply eval-time padding rule: predictions for left-padded (OOB) frames are zero
|
||||
if values.dim() == 2 and len(left_pad_counts) >= (e - s):
|
||||
for b_idx in range(e - s):
|
||||
pad_n = left_pad_counts[s + b_idx]
|
||||
if pad_n > 0:
|
||||
values[b_idx, :pad_n] = 0.0
|
||||
|
||||
# Debug output removed - issue was identified and fixed
|
||||
|
||||
if values.dim() == 2:
|
||||
# Extract the prediction corresponding to each frame's position in its window
|
||||
batch_preds = []
|
||||
for b_idx, pos in enumerate(batch_positions):
|
||||
batch_preds.append(values[b_idx, pos].item())
|
||||
preds[s:e] = np.array(batch_preds)
|
||||
else:
|
||||
# Fallback: if model returns (B,), use as is
|
||||
preds[s:e] = values.detach().float().cpu().numpy()
|
||||
|
||||
return preds
|
||||
|
||||
|
||||
def plot_episode_eval(episode_idx, gt, pred, language, save_path=None, show=False, title_prefix="RLearN Eval"):
|
||||
"""Plot GT vs Predicted over time. Saves PNG if save_path is provided."""
|
||||
T = len(pred)
|
||||
x = np.arange(T)
|
||||
|
||||
plt.figure(figsize=(14, 8))
|
||||
plt.plot(x, pred, linewidth=2.5, marker="o", markersize=3, label="Predicted Reward", color="blue")
|
||||
|
||||
if gt is not None:
|
||||
plt.plot(x, gt, linestyle="--", linewidth=2.5, label="Ground-Truth Reward", color="orange")
|
||||
# Correlation between GT and Pred
|
||||
corr, p = spearmanr(gt, pred)
|
||||
corr_str = f"ρ(GT, Pred) = {0.0 if np.isnan(corr) else corr:.3f} (p={0.0 if np.isnan(p) else p:.3f})"
|
||||
else:
|
||||
expected = np.linspace(0, 1, T)
|
||||
plt.plot(x, expected, linestyle="--", linewidth=2.5, label="Expected Progress (0→1)", color="orange")
|
||||
corr, p = spearmanr(x, pred)
|
||||
corr_str = f"VOC-S ρ(t, Pred) = {0.0 if np.isnan(corr) else corr:.3f} (p={0.0 if np.isnan(p) else p:.3f})"
|
||||
|
||||
plt.title(f"{title_prefix} — Episode {episode_idx}\n{language}\n{corr_str}", fontsize=14)
|
||||
plt.xlabel("Frame Index", fontsize=12)
|
||||
plt.ylabel("Reward / Progress", fontsize=12)
|
||||
plt.legend(fontsize=11)
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path, dpi=200, bbox_inches="tight")
|
||||
print(f"Saved eval image to: {save_path}")
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
else:
|
||||
plt.close()
|
||||
|
||||
|
||||
def eval_episode_sliding(
|
||||
episode_idx, dataset, model, save_dir=".", device="cuda", max_seq_len=16, batch_size=64, title_prefix="RLearN Eval"
|
||||
):
|
||||
"""End-to-end: load episode, predict with sliding 16-frame windows, and save PNG."""
|
||||
frames, language, gt = extract_episode_frames_and_gt(dataset, episode_idx)
|
||||
if frames is None:
|
||||
print(f"[Episode {episode_idx}] No frames found.")
|
||||
return None
|
||||
|
||||
model.eval()
|
||||
|
||||
pred = predict_rewards_sliding(
|
||||
model=model, frames=frames, language=language, max_seq_len=max_seq_len, batch_size=batch_size, device=device
|
||||
)
|
||||
|
||||
# Basic stats
|
||||
print(f"Episode {episode_idx}: T={len(pred)}, pred∈[{pred.min():.3f},{pred.max():.3f}]")
|
||||
if gt is not None:
|
||||
print(f"GT available: gt∈[{np.nanmin(gt):.3f},{np.nanmax(gt):.3f}]")
|
||||
|
||||
save_path = f"{save_dir}/episode_{episode_idx:04d}_eval.png"
|
||||
plot_episode_eval(
|
||||
episode_idx=episode_idx, gt=gt, pred=pred, language=language, save_path=save_path, show=False, title_prefix=title_prefix
|
||||
)
|
||||
return save_path
|
||||
|
||||
|
||||
def main():
|
||||
"""Main evaluation script for RLearN models."""
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description="Evaluate RLearN model on episodes with GT vs Predicted rewards")
|
||||
parser.add_argument("--model", type=str, required=True, help="Model name/path (e.g., pepijn223/rlearn_mse5)")
|
||||
parser.add_argument("--dataset", type=str, required=True, help="Dataset repo (e.g., pepijn223/phone_pipeline_pickup1)")
|
||||
parser.add_argument("--episodes", type=int, default=5, help="Number of episodes to evaluate")
|
||||
parser.add_argument("--output", type=str, default="./eval_results", help="Output directory for images")
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu",
|
||||
help="Device to use",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for sliding window evaluation")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create output directory
|
||||
output_dir = Path(args.output)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("🎯 RLearN Model Evaluation")
|
||||
print("=" * 60)
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Dataset: {args.dataset}")
|
||||
print(f"Episodes: {args.episodes}")
|
||||
print(f"Device: {args.device}")
|
||||
print(f"Output: {output_dir}")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Load dataset
|
||||
print("📁 Loading dataset...")
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=args.dataset,
|
||||
episodes=list(range(min(args.episodes, 50))), # Load enough episodes
|
||||
download_videos=True,
|
||||
)
|
||||
|
||||
print(f"✅ Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
print(f" Features: {list(dataset.features.keys())}")
|
||||
print(f" FPS: {dataset.fps}")
|
||||
|
||||
# Load model
|
||||
print("\n🤖 Loading model...")
|
||||
|
||||
model = RLearNPolicy.from_pretrained(args.model)
|
||||
model = model.to(args.device)
|
||||
model.eval()
|
||||
|
||||
print(f"✅ Model loaded on {args.device}")
|
||||
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||
print(f" Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
|
||||
print(f" Max sequence length: {model.config.max_seq_len}")
|
||||
|
||||
# Select episodes to evaluate
|
||||
total_available = min(dataset.num_episodes, args.episodes)
|
||||
episode_indices = list(range(total_available))
|
||||
|
||||
print(f"\n📊 Evaluating {len(episode_indices)} episodes...")
|
||||
print("=" * 60)
|
||||
|
||||
# Run sliding window evaluation on each episode
|
||||
saved_paths = []
|
||||
for i, ep_idx in enumerate(episode_indices):
|
||||
print(f"\n[{i+1}/{len(episode_indices)}] Processing Episode {ep_idx}")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
save_path = eval_episode_sliding(
|
||||
episode_idx=ep_idx,
|
||||
dataset=dataset,
|
||||
model=model,
|
||||
save_dir=str(output_dir),
|
||||
device=args.device,
|
||||
batch_size=args.batch_size,
|
||||
title_prefix="RLearN Ground Truth vs Predicted",
|
||||
)
|
||||
|
||||
if save_path:
|
||||
saved_paths.append(save_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing episode {ep_idx}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ EVALUATION COMPLETE")
|
||||
print(f"📈 Generated {len(saved_paths)} evaluation plots")
|
||||
print(f"📁 Results saved to: {output_dir}")
|
||||
print("\nGenerated files:")
|
||||
for path in saved_paths:
|
||||
print(f" • {path}")
|
||||
|
||||
if saved_paths:
|
||||
print(f"\n💡 View the plots to compare ground truth vs predicted rewards!")
|
||||
print(f" Each plot shows the model's sliding 16-frame window predictions")
|
||||
print(f" against available ground truth rewards over the episode timeline.")
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during evaluation: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,128 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE
|
||||
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import (
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
|
||||
def make_rlearn_processor(
|
||||
config: RLearNConfig, dataset_stats: dict[str, dict[str, Any]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
"""Build pre/post processors for RLearN.
|
||||
|
||||
Responsibilities moved out of the model:
|
||||
- Normalize inputs (images) using dataset stats
|
||||
- Ensure batching
|
||||
- Map complementary_data.task to observation.language when available
|
||||
- Tokenize language into observation.language.tokens / attention_mask
|
||||
- Move to/from device
|
||||
"""
|
||||
|
||||
input_steps = [
|
||||
# No renaming by default, but keep for future extensibility
|
||||
RenameProcessor(rename_map={}),
|
||||
# Move heavy normalization to GPU after transfer for better parallelism
|
||||
ToBatchProcessor(),
|
||||
RLearnLanguageFromTaskProcessor(),
|
||||
# Use SigLIP2 for tokenizer to keep vocab aligned with text tower
|
||||
TokenizerProcessor(
|
||||
tokenizer_name=config.text_model_name,
|
||||
max_length=64,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
padding_side="right",
|
||||
),
|
||||
DeviceProcessor(device=config.device),
|
||||
# Move normalization after GPU transfer to use GPU acceleration
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="rlearn_language_from_task")
|
||||
class RLearnLanguageFromTaskProcessor(ComplementaryDataProcessor):
|
||||
"""Copy complementary_data['task'] into observation['observation.language'] if present.
|
||||
|
||||
This ensures the model can consume a raw language string when tokenization is not used,
|
||||
while TokenizerProcessor can still create tokenized fields.
|
||||
"""
|
||||
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition: # type: ignore[override]
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if not complementary_data or self.task_key not in complementary_data:
|
||||
return transition
|
||||
|
||||
task = complementary_data.get(self.task_key)
|
||||
if task is None:
|
||||
return transition
|
||||
|
||||
# Normalize to list[str]
|
||||
if isinstance(task, str):
|
||||
task_list = [task]
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
task_list = task
|
||||
else:
|
||||
return transition
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION) or {}
|
||||
# Do not overwrite if user already provided observation.language
|
||||
if OBS_LANGUAGE not in observation:
|
||||
observation[OBS_LANGUAGE] = task_list
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: # noqa: D401
|
||||
# Adds nothing to features; only mirrors complementary_data.task into observation
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"task_key": self.task_key}
|
||||
@@ -1,101 +0,0 @@
|
||||
## General Value/Reward Learning:
|
||||
|
||||
I want to implement a general/universal vision and language value function or reward model for robotics/video tasks. Also called a video language conditioned reward model. Integrated with already existing LeRobot code if convenient, use the LeRobot Dataset for dataset and store the reward for a frame in the lerobot frame itself.
|
||||
|
||||
Inspired by these papers:
|
||||
|
||||
- ReWiND; https://arxiv.org/pdf/2505.10911 (Most applicable and main paper I want to implement ideas from) and code: https://github.com/lucidrains/rewind-reward-pytorch
|
||||
- LIV; https://arxiv.org/pdf/2306.00958 (Most applicable and 2nd main paper I want to implement ideas from) and code https://github.com/penn-pal-lab/LI
|
||||
- VLC: Video-Language Critic: Transferable Reward Functions for Language-Conditioned Robotics: https://arxiv.org/pdf/2405.19988 (Most applicable and 3rd paper I want to implement ideas from) and code: https://github.com/minttusofia/video_language_critic
|
||||
|
||||
And these papers which are also relevant:
|
||||
|
||||
- https://www.dyna.co/dyna-1/research (Main company I want to reproduce the eventual results from)
|
||||
- vip; https://arxiv.org/pdf/2210.00030
|
||||
- uvd; https://arxiv.org/pdf/2310.08581
|
||||
- vlm in context; https://arxiv.org/pdf/2411.04549
|
||||
- https://www.youtube.com/watch?v=JfZYtpEisoM
|
||||
|
||||
Little less relevant but still similar papers:
|
||||
|
||||
- Learning Generalizable Robotic Reward Functions from “In-The-Wild” Human Videos,
|
||||
- XIRL: Cross-embodiment Inverse Reinforcement Learning,
|
||||
- Video-Language Critic: Transferable Reward https://arxiv.org/pdf/2405.19988
|
||||
- Functions for Language-Conditioned Robotics,
|
||||
- LORel, Language-Driven Representation Learning for Robotics https://sites.google.com/view/robotlorel
|
||||
- RoboCLIP: One Demonstration is Enough to Learn Robot Policies https://arxiv.org/pdf/2310.07899
|
||||
- Points2Rewards: learn first key points and then uses the keypoints to learn general value function/policy https://semrob.github.io/docs/2025_rss_semrob.github.io_paper20.pdf
|
||||
- Language-Driven Representation Learning for Robotics: https://arxiv.org/pdf/2302.12766v1
|
||||
- R3M: A Universal Visual Representation for Robot Manipulation: https://arxiv.org/pdf/2203.12601v3
|
||||
|
||||
Input should be the current image or whole video and the task goal specified in text/language. Output is current reward.
|
||||
Archiutecture:
|
||||
_ inputs: video o1:T (or current o1:t), language z;
|
||||
_ DINO v3 ViT-B/16 (86M params): https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m for vision encoding
|
||||
\_ sentence-transformers/all-MiniLM-L12-v2: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 for text encoding \* Temporal module: small causal transformer ("cross-modal sequential aggregator"), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
|
||||
|
||||
Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01
|
||||
|
||||
Past images: (for example a reward method go to 3rd floor, has to know what floor it was on and what pas actions it did, can we attend or encorperate images of decision from history in one way?) Maybe via this paper: Learning Long-Context Diffusion Policies via Past-Token Prediction
|
||||
|
||||
Amount of frames needed for test/generalization: 1M frames? or ~20% of IPEC-COMMUNITY/bc_z_lerobot
|
||||
|
||||
Eval:
|
||||
Implement something like voc score , or ROC rank order correlation between reward leanredna and ev reward from sim, or use something else to do additional evaluation
|
||||
|
||||
Ideas:
|
||||
|
||||
- Incorporate training on multiple horizons: as in label same dataset for longer horizons: make a sandwich (long), put cheese on bread (medium) and even smaller horizons: go down or close gripper (small)
|
||||
- Incorporate navigation goals “walk towards the kitchen”, make sure we fix CLIP contrastive learning issue of positional text misunderstanding where model doesnnt learn difference between "horse right of cow" and "horse left of cow" “Move right” potentially train with more other data or even actionable world models such as Genie 3 (https://deepmind.google/discover/blog/genie-3-a-new-frontier-for-world-models/)
|
||||
|
||||
How to use a general reward model (use cases): - Train rl policy on it - Success detection - Do exploraion - Do task via planning and search to optimize reward - Filter out bad episodes in large datasets from imitation learning
|
||||
|
||||
Potential Datasets: (start with dataset that is most clean for this and works best with chosen way of doing evals)
|
||||
_ Epic-Kitchens-100
|
||||
_ Something-Something v. 2 Dataset https://www.qualcomm.com/developer/software/something-something-v-2-dataset
|
||||
_ Ego4D (3000 hours)
|
||||
_ Open X-Embodiment (OXE)
|
||||
\_ Agi bot world: https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha
|
||||
|
||||
- GalexiAI dataset: https://opengalaxea.github.io/G0/
|
||||
_ GTEA+ Gaze: https://cbs.ic.gatech.edu/fpv/
|
||||
_ YouCook2 dataset
|
||||
\_ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
||||
- Genie generated dataset?
|
||||
|
||||
### TODOs:
|
||||
|
||||
- Implement first architecture [x]
|
||||
- Implement processors [x]
|
||||
- Choose right loss metric(s) [x]
|
||||
- Make dataset with script that generated the dataset (IPEC-COMMUNITY/bc_z_lerobot) ready in lerobot format (and be able to visualize in dataset visualizer)
|
||||
- Annotate with ReWiND-style 0→1 progress rewards [x]
|
||||
- Visualize to check [x]
|
||||
- Implement eval score or metric that is robust and can deal with generalization/is a good metric to try different architectures. And use it in an eval jupyter notebook with visalization of the live reward next to the video for part of the dataset: VOC score and score with correct and incorrect language captions [x]
|
||||
- Do first training [x]
|
||||
- Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x]
|
||||
- Try different losses
|
||||
- Only rewind loss [x]
|
||||
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
|
||||
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
||||
- Test rewind (evaluate) [x]
|
||||
- benchmark siglip 2 vs this implementation forward pass, debug speed [x]
|
||||
- use siglip 2 [x]
|
||||
- Fix evaluation bug !!! []
|
||||
- Fix sample episode padding bug !!! []
|
||||
- Overfit on one episode []
|
||||
- Cleanup code? [] + enable language loss
|
||||
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
||||
- Then on 10 percent []
|
||||
- Ablation 16 sucessive frame vs 16 frame samples with stride 2 or 4 []
|
||||
- Add more artificial text to dataset generated by vlm (google gemini) []
|
||||
- See google gemini vlm caption [] https://gemini.google.com/app/7e332ffaf32580f2
|
||||
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
|
||||
- Add other datasets from OXE metioned in rewind []
|
||||
- Extend evaluation []
|
||||
- Ablation for size vision encoder, language encoder, temporal head []
|
||||
- Ablation one mlp head per frame or single mlp head []
|
||||
- Add other datasets metnioned here []
|
||||
- How can we improve spatial aware learning? solve issue of Contrastive learning and position []
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
from lerobot.policies.normalize import NormalizeBuffer
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
@@ -44,6 +45,7 @@ class SACPolicy(
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -51,6 +53,7 @@ class SACPolicy(
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
self._init_normalization(dataset_stats)
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
@@ -85,7 +88,8 @@ class SACPolicy(
|
||||
|
||||
observations_features = None
|
||||
if self.shared_encoder and self.actor.encoder.has_images:
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch)
|
||||
# Cache and normalize image features
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
@@ -387,12 +391,28 @@ class SACPolicy(
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_normalization(self, dataset_stats):
|
||||
"""Initialize input/output normalization modules."""
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = nn.Identity()
|
||||
if self.config.dataset_stats is not None:
|
||||
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
|
||||
self.normalize_inputs = NormalizeBuffer(
|
||||
self.config.input_features, self.config.normalization_mapping, params
|
||||
)
|
||||
stats = dataset_stats or params
|
||||
self.normalize_targets = NormalizeBuffer(
|
||||
self.config.output_features, self.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
self.encoder_critic
|
||||
if self.shared_encoder
|
||||
else SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
@@ -404,7 +424,9 @@ class SACPolicy(
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
@@ -412,7 +434,9 @@ class SACPolicy(
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
@@ -466,9 +490,10 @@ class SACPolicy(
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self._init_image_layers()
|
||||
self._init_state_layers()
|
||||
self._compute_output_dim()
|
||||
@@ -543,10 +568,11 @@ class SACObservationEncoder(nn.Module):
|
||||
def forward(
|
||||
self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
|
||||
) -> Tensor:
|
||||
obs = self.input_normalization(obs)
|
||||
parts = []
|
||||
if self.has_images:
|
||||
if cache is None:
|
||||
cache = self.get_cached_image_features(obs)
|
||||
cache = self.get_cached_image_features(obs, normalize=False)
|
||||
parts.append(self._encode_images(cache, detach))
|
||||
if self.has_env:
|
||||
parts.append(self.env_encoder(obs["observation.environment_state"]))
|
||||
@@ -559,7 +585,7 @@ class SACObservationEncoder(nn.Module):
|
||||
"No parts to concatenate, you should have at least one image or environment state or state"
|
||||
)
|
||||
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
|
||||
"""Extract and optionally cache image features from observations.
|
||||
|
||||
This function processes image observations through the vision encoder once and returns
|
||||
@@ -571,17 +597,26 @@ class SACObservationEncoder(nn.Module):
|
||||
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
|
||||
- Caching these features can provide 2-4x speedup in training and inference
|
||||
|
||||
Normalization behavior:
|
||||
- When called from inside forward(): set normalize=False since inputs are already normalized
|
||||
- When called from outside forward(): set normalize=True to ensure proper input normalization
|
||||
|
||||
Usage patterns:
|
||||
- Called in select_action()
|
||||
- Called in select_action() with normalize=True
|
||||
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
|
||||
- Called internally by forward()
|
||||
- Called internally by forward() with normalize=False
|
||||
|
||||
Args:
|
||||
obs: Dictionary of observation tensors containing image keys
|
||||
normalize: Whether to normalize observations before encoding
|
||||
Set to True when calling directly from outside the encoder's forward method
|
||||
Set to False when calling from within forward() where inputs are already normalized
|
||||
|
||||
Returns:
|
||||
Dictionary mapping image keys to their corresponding encoded features
|
||||
"""
|
||||
if normalize:
|
||||
obs = self.input_normalization(obs)
|
||||
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
|
||||
out = self.image_encoder(batched)
|
||||
chunks = torch.chunk(out, len(self.image_keys), dim=0)
|
||||
@@ -712,6 +747,7 @@ class CriticEnsemble(nn.Module):
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
output_normalization (nn.Module): normalization layer for actions.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
@@ -721,11 +757,13 @@ class CriticEnsemble(nn.Module):
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
output_normalization: nn.Module,
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
@@ -737,6 +775,11 @@ class CriticEnsemble(nn.Module):
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
# NOTE: We normalize actions it helps for sample efficiency
|
||||
actions: dict[str, torch.tensor] = {"action": actions}
|
||||
# NOTE: Normalization layer took dict in input and outputs a dict that why
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_sac_processor(
|
||||
config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
)
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import OBS_IMAGE, REWARD
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
@@ -107,12 +108,22 @@ class Classifier(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize normalization (standardized with the policy framework)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
@@ -236,6 +247,10 @@ class Classifier(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Normalize inputs if needed
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
IdentityProcessor,
|
||||
NormalizerProcessor,
|
||||
RobotProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_classifier_processor(
|
||||
config: RewardClassifierConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
input_steps = [
|
||||
NormalizerProcessor(
|
||||
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
NormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
|
||||
return RobotProcessor(steps=input_steps, name="classifier_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="classifier_postprocessor"
|
||||
)
|
||||
@@ -28,7 +28,7 @@ pip install -e ".[smolvla]"
|
||||
|
||||
Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
@@ -38,7 +38,7 @@ python -m lerobot.scripts.train \
|
||||
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
|
||||
and an action expert.
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
@@ -53,13 +53,21 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import (
|
||||
Normalize,
|
||||
Unnormalize,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
@@ -68,6 +76,102 @@ from lerobot.policies.utils import (
|
||||
)
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||
|
||||
|
||||
def canonicalise(k: str) -> str:
|
||||
"""
|
||||
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
|
||||
normalisation-buffer key.
|
||||
"""
|
||||
return _VARIANT_RE.sub(".buffer_", k)
|
||||
|
||||
|
||||
def standardise_state_dict(
|
||||
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
|
||||
) -> tuple[dict[str, torch.Tensor], list[str]]:
|
||||
"""
|
||||
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
|
||||
• If several variant keys collapse to the same canonical name we keep the
|
||||
first one and log the collision.
|
||||
• Returns the new dict + a list of entries that could not be matched.
|
||||
"""
|
||||
out, collisions, unmatched = {}, {}, []
|
||||
|
||||
for k, v in checkpoint.items():
|
||||
canon = canonicalise(k)
|
||||
if canon in ref_keys:
|
||||
if canon in out: # duplicate after collapsing
|
||||
collisions.setdefault(canon, []).append(k)
|
||||
else:
|
||||
out[canon] = v
|
||||
else:
|
||||
unmatched.append(k)
|
||||
|
||||
if verbose:
|
||||
for canon, variants in collisions.items():
|
||||
print(f"[standardise_state_dict] '{canon}' ← {variants}")
|
||||
if unmatched:
|
||||
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
|
||||
|
||||
out.update({k: checkpoint[k] for k in unmatched})
|
||||
return out, unmatched
|
||||
|
||||
|
||||
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
||||
"""
|
||||
Renames keys in a checkpoint dictionary based on the given rename string.
|
||||
|
||||
Args:
|
||||
checkpoint (dict): The checkpoint dictionary.
|
||||
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
|
||||
|
||||
Returns:
|
||||
dict: The modified checkpoint with renamed keys.
|
||||
"""
|
||||
|
||||
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
|
||||
|
||||
new_checkpoint = {}
|
||||
for k, v in checkpoint.items():
|
||||
for old_key, new_key in rename_dict.items():
|
||||
if old_key in k:
|
||||
k = k.replace(old_key, new_key)
|
||||
new_checkpoint[k] = v
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def load_smolvla(
|
||||
model: torch.nn.Module,
|
||||
filename: str | os.PathLike,
|
||||
*,
|
||||
device: str = "cpu",
|
||||
checkpoint_keys_mapping: str = "",
|
||||
) -> torch.nn.Module:
|
||||
state_dict = safetensors.torch.load_file(filename, device=device)
|
||||
|
||||
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
|
||||
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
|
||||
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
|
||||
|
||||
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
|
||||
|
||||
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
|
||||
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
|
||||
|
||||
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
|
||||
raise RuntimeError(
|
||||
"SmolVLA %d missing / %d unexpected keys",
|
||||
len(missing),
|
||||
len(unexpected),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
@@ -222,17 +326,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.reset()
|
||||
|
||||
@@ -242,6 +357,23 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
|
||||
@classmethod
|
||||
def _load_as_safetensor(
|
||||
cls,
|
||||
model: "SmolVLAPolicy",
|
||||
model_file: str,
|
||||
map_location: str,
|
||||
strict: bool,
|
||||
):
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
return load_smolvla(
|
||||
model,
|
||||
model_file,
|
||||
device=map_location,
|
||||
checkpoint_keys_mapping="model._orig_mod.//model.",
|
||||
)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@@ -257,8 +389,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||
|
||||
@@ -266,6 +397,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -275,6 +408,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -315,11 +450,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
@@ -383,6 +518,30 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_STATE].device
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
if len(tasks) == 1:
|
||||
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
||||
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding=self.config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
def make_smolvla_processor(
|
||||
config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
SmolVLANewLineProcessor(),
|
||||
TokenizerProcessor(
|
||||
tokenizer_name=config.vlm_model_name,
|
||||
padding=config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=config.tokenizer_max_length,
|
||||
),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
|
||||
class SmolVLANewLineProcessor(ProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Check if complementary_data exists
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None or "task" not in complementary_data:
|
||||
return transition
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return transition
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Adds nothing to the features."""
|
||||
return features
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
@@ -36,6 +36,7 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
@@ -62,19 +63,26 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig,
|
||||
):
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = TDMPCTOLD(config)
|
||||
self.model_target = deepcopy(self.model)
|
||||
for param in self.model_target.parameters():
|
||||
@@ -129,6 +137,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
actions = torch.clamp(actions, -1, +1)
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -138,12 +147,11 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -312,9 +320,11 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_tdmpc_processor(
|
||||
config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user