mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
122 Commits
ecd38c50d7
...
pr/2545
| Author | SHA1 | Date | |
|---|---|---|---|
| 3dbf463f3b | |||
| bc7740a15d | |||
| 03e009db78 | |||
| 0794b4ba8f | |||
| eb85d2d541 | |||
| d7537c85c5 | |||
| b2a32b8076 | |||
| cb296ee58f | |||
| 428ea89ff2 | |||
| b9cb947bd2 | |||
| 8440c561ae | |||
| 11cefed08a | |||
| 7bfedd1388 | |||
| 8c95a71c94 | |||
| 1d048c7e2b | |||
| 419305a4c2 | |||
| 23d04ca6fd | |||
| 753b996cda | |||
| 099f3ba4d7 | |||
| 3f3d08e5a8 | |||
| 9e1a67c862 | |||
| 54c38627bd | |||
| f0ef3717ca | |||
| bd8e1ccf70 | |||
| 6243b70239 | |||
| cc4377c346 | |||
| 50e13da845 | |||
| 7cca09d3da | |||
| ddfd573853 | |||
| 6d87b52abd | |||
| e10e352ca0 | |||
| d755df4461 | |||
| 0840dd6b94 | |||
| 9c4c988573 | |||
| f51e9fc1e4 | |||
| 802d73edec | |||
| 3dd51c335e | |||
| b8e36bb2f3 | |||
| 2f3b1b3db8 | |||
| 8f33a1392e | |||
| f3dc152e94 | |||
| 673e18ab46 | |||
| f4683127b2 | |||
| bdac9d7df8 | |||
| c71b30a06c | |||
| 238ea68382 | |||
| 2abb20262e | |||
| ef3b299145 | |||
| c95d151913 | |||
| 699191f81f | |||
| 6d9271940b | |||
| c24cbaacf9 | |||
| 0e7bfa5624 | |||
| fdcbd1a936 | |||
| dd39503802 | |||
| db17c08f6e | |||
| b1ab9b9c46 | |||
| 5c6714bc1b | |||
| 7581adf4cd | |||
| e9384ec834 | |||
| df52c559d3 | |||
| a011accb7f | |||
| b98ccdde3a | |||
| 473528cb14 | |||
| 713efb6427 | |||
| e268ec1ec5 | |||
| d75f3f8915 | |||
| 2cf13d9d63 | |||
| 8755bd0637 | |||
| 634e3924b8 | |||
| f5f9833540 | |||
| e2b47a142a | |||
| 2a3444a8dd | |||
| 3e5f31e0be | |||
| d653f96420 | |||
| 2b90763597 | |||
| 632c778a2b | |||
| 77dbc951cd | |||
| 5b9f98169c | |||
| 2128dece82 | |||
| b575632f4f | |||
| 8a2f5aa6cb | |||
| 25ecd16b67 | |||
| 1e049fbef7 | |||
| afe2c4d3aa | |||
| 534e143b0c | |||
| 23382c0ac5 | |||
| 4eda54c7fe | |||
| a632dd3af4 | |||
| e4a1b27fd3 | |||
| 71f359ca6e | |||
| 51dfee43f4 | |||
| 1f74982469 | |||
| 8e3a1e8945 | |||
| 43c335d0d7 | |||
| dd4ef1383f | |||
| f3823e8bcd | |||
| c398a146b3 | |||
| 9b47c5fac9 | |||
| 56dbeed89f | |||
| 67b1a9eeb1 | |||
| 86e0ee787d | |||
| ba968e84f1 | |||
| d49d3390f6 | |||
| f1ac454800 | |||
| 10cfc17705 | |||
| 5524a0d7a7 | |||
| 3a16a002f8 | |||
| 3b2a4f548c | |||
| b92dc82ddd | |||
| 103230c64c | |||
| cdacc090cd | |||
| 6f856016c5 | |||
| adabb37af6 | |||
| 55e19ff9a7 | |||
| 22714af08d | |||
| 46ebcc2f7d | |||
| a0d5a088e3 | |||
| 34499cbc1b | |||
| 8b9fada80f | |||
| ab97d5c019 | |||
| 14a7a4d7d4 |
@@ -100,11 +100,11 @@ lerobot-train \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
| Category | Models |
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
|
||||
@@ -47,6 +47,8 @@
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
- local: multi_task_dit
|
||||
title: Multitask DiT Policy
|
||||
- local: walloss
|
||||
title: WALL-OSS
|
||||
title: "Policies"
|
||||
|
||||
@@ -0,0 +1,340 @@
|
||||
# Multitask DiT Policy
|
||||
|
||||
Multitask Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.
|
||||
|
||||
## Model Overview
|
||||
|
||||
The model uses:
|
||||
|
||||
- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
|
||||
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
|
||||
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
|
||||
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation
|
||||
|
||||
This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter
|
||||
VLAs, with only ~450M parameters and significantly less training.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Multitask DiT Policy has additional dependencies. Install it with:
|
||||
|
||||
```bash
|
||||
pip install lerobot[multi_task_dit]
|
||||
```
|
||||
|
||||
This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models.
|
||||
|
||||
## Usage
|
||||
|
||||
To use Multitask DiT in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=multi_task_dit
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Basic Training Command
|
||||
|
||||
Here's a complete training command for training Multitask DiT on your dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/multitask_dit_training \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--save_freq=500 \
|
||||
--log_freq=100 \
|
||||
--policy.type=multi_task_dit \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
|
||||
|
||||
For reliable performance, start with these suggested default hyperparameters:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/mutitask_dit_training \
|
||||
--batch_size=320 \
|
||||
--steps=30000 \
|
||||
--policy.type=multi_task_dit \
|
||||
--policy.device=cuda \
|
||||
--policy.horizon=32 \
|
||||
--policy.n_action_steps=24 \
|
||||
--policy.objective=diffusion \
|
||||
--policy.noise_scheduler_type=DDPM \
|
||||
--policy.num_train_timesteps=100 \
|
||||
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
|
||||
- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics
|
||||
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
|
||||
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
|
||||
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
|
||||
- **Training Steps**: >30k steps recommended for a single task
|
||||
|
||||
### Training Configuration Parameters
|
||||
|
||||
#### Objective Selection
|
||||
|
||||
Choose between diffusion and flow matching:
|
||||
|
||||
```bash
|
||||
# Diffusion objective (default)
|
||||
--policy.objective=diffusion \
|
||||
--policy.noise_scheduler_type=DDPM \ # or "DDIM"
|
||||
--policy.num_train_timesteps=100 \
|
||||
--policy.num_inference_steps=10 \ # For faster inference
|
||||
--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type
|
||||
--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean)
|
||||
--policy.clip_sample=true \ # Clip samples during denoising
|
||||
--policy.clip_sample_range=1.0 # Clipping range [-x, x]
|
||||
|
||||
# Flow matching objective
|
||||
--policy.objective=flow_matching \
|
||||
--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta sampling strategy performance appears much better in practice
|
||||
--policy.num_integration_steps=100 \
|
||||
--policy.integration_method=euler \ # or "rk4"
|
||||
--policy.sigma_min=0.0 # Minimum noise in flow interpolation path
|
||||
```
|
||||
|
||||
#### Transformer Architecture
|
||||
|
||||
Adjust model capacity based on dataset size:
|
||||
|
||||
```bash
|
||||
# Small datasets (< 100 examples)
|
||||
--policy.num_layers=4 \
|
||||
--policy.hidden_dim=512 \
|
||||
--policy.num_heads=8 # should ideally be hidden_dim // 64
|
||||
|
||||
# Medium datasets (100-5k examples) - default
|
||||
--policy.num_layers=6 \
|
||||
--policy.hidden_dim=512 \
|
||||
--policy.num_heads=8 # should ideally be hidden_dim // 64
|
||||
|
||||
# Large datasets (> 5k examples)
|
||||
--policy.num_layers=8 \
|
||||
--policy.hidden_dim=512 \
|
||||
--policy.num_heads=8 # should ideally be hidden_dim // 64
|
||||
```
|
||||
|
||||
**Positional Encoding Options:**
|
||||
|
||||
The model supports two positional encoding methods for action sequences:
|
||||
|
||||
```bash
|
||||
# Rotary Position Embedding (RoPE) - default, recommended
|
||||
--policy.use_rope=true \
|
||||
--policy.rope_base=10000.0 # Base frequency for RoPE
|
||||
|
||||
# Absolute positional encoding
|
||||
--policy.use_positional_encoding=true # Disables RoPE when true
|
||||
```
|
||||
|
||||
**Other Transformer Parameters:**
|
||||
|
||||
```bash
|
||||
--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0)
|
||||
--policy.timestep_embed_dim=256 # Timestep embedding dimension
|
||||
```
|
||||
|
||||
#### Vision Encoder Configuration
|
||||
|
||||
```bash
|
||||
# Use different CLIP model for more expressivity at the cost of inference time
|
||||
# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset
|
||||
--policy.vision_encoder_name=openai/clip-vit-large-patch14
|
||||
|
||||
# Use separate vision encoder per camera
|
||||
# This may be useful when cameras have significantly different characteristics, but
|
||||
# be wary of increased VRAM footprint.
|
||||
--policy.use_separate_rgb_encoder_per_camera=true
|
||||
|
||||
# Image preprocessing
|
||||
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
|
||||
--policy.image_crop_shape=[224,224] \
|
||||
--policy.image_crop_is_random=true # Random during training, center at inference
|
||||
```
|
||||
|
||||
#### Text Encoder Configuration
|
||||
|
||||
```bash
|
||||
# Use different CLIP text encoder model
|
||||
# same as vision: experiment with larger or smaller models depending on the
|
||||
# complexity of your tasks and size of dataset
|
||||
--policy.text_encoder_name=openai/clip-vit-large-patch14
|
||||
```
|
||||
|
||||
#### Learning Rate Configuration
|
||||
|
||||
The vision encoder uses a separate learning rate multiplier, where 1/10th is suggested to be the ideal staritng point:
|
||||
|
||||
```bash
|
||||
--policy.optimizer_lr=2e-5 \
|
||||
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
|
||||
```
|
||||
|
||||
### Training Tuning Guidelines
|
||||
|
||||
#### 1. Flow Matching with Beta Sampling
|
||||
|
||||
The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331)
|
||||
|
||||
Additionally, we have implemented a flow-matching objective, which is described at a high-level in [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).
|
||||
|
||||
Consider testing the flow-matching objective and evaluating performance differences for your task:
|
||||
|
||||
```bash
|
||||
--policy.objective=flow_matching \
|
||||
--policy.timestep_sampling_strategy=beta \
|
||||
--policy.timestep_sampling_alpha=1.5 \
|
||||
--policy.timestep_sampling_beta=1.0 \
|
||||
--policy.timestep_sampling_s=0.999
|
||||
```
|
||||
|
||||
This hasn't been shown to be a silver bullet across every user case, but it occasionally results in smoother and more consistent actions.
|
||||
|
||||
#### 2. Number of Transformer Layers
|
||||
|
||||
Match model capacity to your dataset size:
|
||||
|
||||
- **Small datasets** (< 100 examples): Reduce to 4 layers
|
||||
- **Large datasets** (> 5k examples): Increase to 8 layers
|
||||
|
||||
#### 3. `horizon` Tuning
|
||||
|
||||
The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency:
|
||||
|
||||
- **30 Hz frequency**: `horizon=30`
|
||||
- **10 Hz frequency**: `horizon=10`
|
||||
|
||||
Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.
|
||||
|
||||
#### 4. `n_action_steps` Sensitivity
|
||||
|
||||
The model can also be very sensitive to `n_action_steps`. Start with it being around 0.8 seconds based on your control frequency and tune from there:
|
||||
|
||||
- **Lower values**: More reactive but potentially less stable for long-horizon tasks
|
||||
- **Higher values**: Better for long-horizon execution but open-loop failures are limited in their recovery
|
||||
|
||||
### Inference Tuning
|
||||
|
||||
For faster inference, use DDIM with fewer sampling steps:
|
||||
|
||||
```bash
|
||||
--policy.noise_scheduler_type=DDIM \
|
||||
--policy.num_inference_steps=10
|
||||
```
|
||||
|
||||
### Resuming Training
|
||||
|
||||
To resume training from a checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--config_path=./outputs/mutitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
```
|
||||
|
||||
The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.
|
||||
|
||||
## Common Failure Modes and Debugging
|
||||
|
||||
Training these models can be finicky. Here are common failure modes and debugging approaches:
|
||||
|
||||
### Idling / No Motion
|
||||
|
||||
The model may "collapse" during inference, resulting in static or no motion. This can occur when:
|
||||
|
||||
1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex.
|
||||
|
||||
2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.
|
||||
|
||||
**Debugging tips:**
|
||||
|
||||
- Increase dataset size (double until you get to over 300 examples)
|
||||
- Train for longer, up to 100k steps, even when the loss flatlines
|
||||
- Check if the model is receiving proper language instructions or increase diversity of instruction
|
||||
|
||||
### Executing the Wrong Task
|
||||
|
||||
Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.
|
||||
|
||||
**Potential causes:**
|
||||
|
||||
- Language instruction ambiguity
|
||||
- Insufficient task-specific training data
|
||||
- Model confusion between similar tasks in the multitask dataset
|
||||
|
||||
**Debugging tips:**
|
||||
|
||||
- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
|
||||
- Check task distribution in your training dataset and add weighting to the failing/ignored task
|
||||
- Consider task-specific fine-tuning
|
||||
|
||||
### Training Instability
|
||||
|
||||
If training loss is unstable or diverging:
|
||||
|
||||
- Try adjusting learning rate between `1e-5` and `3e-4`
|
||||
- Increase batch size if possible
|
||||
- Check that your dataset normalization is correct
|
||||
- Verify image preprocessing is working correctly
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### GPU Requirements
|
||||
|
||||
- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable speed performance
|
||||
- **Training**: A GPU with enough VRAM to load batch sizes of >64 is ideal, which will vary depending on the number of image observations, etc
|
||||
|
||||
### Batch Size Recommendations
|
||||
|
||||
- **Minimum**: 64 (less than this may result in unstable training)
|
||||
- **Recommended**: 256-320 (best performance, requires larger GPU)
|
||||
|
||||
## Example: Training on Custom Dataset
|
||||
|
||||
Here's a complete example training on a custom dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/mutitask_dit_training \
|
||||
--batch_size=320 \
|
||||
--steps=30000 \
|
||||
--save_freq=1000 \
|
||||
--log_freq=100 \
|
||||
--eval_freq=1000 \
|
||||
--policy.type=multi_task_dit \
|
||||
--policy.device=cuda \
|
||||
--policy.horizon=32 \
|
||||
--policy.n_action_steps=24 \
|
||||
--policy.objective=diffusion \
|
||||
--policy.noise_scheduler_type=DDPM \
|
||||
--policy.num_layers=6 \
|
||||
--policy.hidden_dim=512 \
|
||||
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
|
||||
--policy.image_resize_shape=[320,240] \
|
||||
--policy.image_crop_shape=[224,224] \
|
||||
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=multitask_dit
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
For more details on the technical implementation and architecture, see:
|
||||
|
||||
- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
|
||||
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
|
||||
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)
|
||||
@@ -0,0 +1,37 @@
|
||||
# Multitask DiT Policy
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite the following works:
|
||||
|
||||
```bibtex
|
||||
@misc{jones2025multitaskditpolicy,
|
||||
author = {Bryson Jones},
|
||||
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
|
||||
year = {2025},
|
||||
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
|
||||
note = {Blog post}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
|
||||
author = {TRI LBM Team},
|
||||
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
|
||||
year = {2025},
|
||||
eprint = {arXiv:2507.05331},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2507.05331}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{bostondynamics2025largebehaviormodelsatlas,
|
||||
author = {Boston Dynamics and TRI Research Team},
|
||||
title = {Large Behavior Models and Atlas Find New Footing},
|
||||
year = {2025},
|
||||
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
|
||||
note = {Blog post}
|
||||
}
|
||||
```
|
||||
@@ -144,6 +144,7 @@ wallx = [
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[peft]",
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
@@ -28,6 +29,7 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
"PI05Config",
|
||||
"PI0FastConfig",
|
||||
|
||||
@@ -31,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
@@ -66,8 +67,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -86,6 +86,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
return ACTPolicy
|
||||
elif name == "multi_task_dit":
|
||||
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||
|
||||
return MultiTaskDiTPolicy
|
||||
elif name == "vqbet":
|
||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
@@ -146,8 +150,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||
"reward_classifier", "wall_x".
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"smolvla", "reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -162,6 +166,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "act":
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "multi_task_dit":
|
||||
return MultiTaskDiTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
@@ -308,6 +314,16 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, MultiTaskDiTConfig):
|
||||
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
|
||||
make_multi_task_dit_pre_post_processors,
|
||||
)
|
||||
|
||||
processors = make_multi_task_dit_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VQBeTConfig):
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
|
||||
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
# Multitask DiT Policy
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite the following works:
|
||||
|
||||
```bibtex
|
||||
@misc{jones2025multitaskditpolicy,
|
||||
author = {Bryson Jones},
|
||||
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
|
||||
year = {2025},
|
||||
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
|
||||
note = {Blog post}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
|
||||
author = {TRI LBM Team},
|
||||
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
|
||||
year = {2025},
|
||||
eprint = {arXiv:2507.05331},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2507.05331}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{bostondynamics2025largebehaviormodelsatlas,
|
||||
author = {Boston Dynamics and TRI Research Team},
|
||||
title = {Large Behavior Models and Atlas Find New Footing},
|
||||
year = {2025},
|
||||
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
|
||||
note = {Blog post}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||
from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors
|
||||
|
||||
__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"]
|
||||
@@ -0,0 +1,256 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamConfig
|
||||
from lerobot.optim.schedulers import DiffuserSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("multi_task_dit")
|
||||
@dataclass
|
||||
class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
"""Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
|
||||
|
||||
A transformer-based policy that supports both diffusion and flow matching objectives
|
||||
for multi-task robot learning with text and vision conditioning.
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 2 # Number of observation steps for temporal context
|
||||
horizon: int = 32 # Number of action steps to predict
|
||||
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
|
||||
|
||||
# Objective Selection
|
||||
objective: str = "diffusion" # "diffusion" or "flow_matching"
|
||||
|
||||
# --- Diffusion-specific (used when objective="diffusion") ---
|
||||
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
|
||||
num_train_timesteps: int = 100 # Number of diffusion timesteps
|
||||
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
|
||||
beta_start: float = 0.0001 # Starting noise level
|
||||
beta_end: float = 0.02 # Ending noise level
|
||||
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
|
||||
clip_sample: bool = True # Clip samples during denoising
|
||||
clip_sample_range: float = 1.0 # Clipping range [-x, x]
|
||||
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
|
||||
|
||||
# --- Flow Matching-specific (used when objective="flow_matching") ---
|
||||
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
|
||||
num_integration_steps: int = 100 # ODE integration steps at inference
|
||||
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
|
||||
timestep_sampling_strategy: str = "beta" # "uniform" or "beta"
|
||||
|
||||
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
|
||||
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
|
||||
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
|
||||
|
||||
# Transformer Architecture
|
||||
hidden_dim: int = 512 # Transformer hidden dimension
|
||||
num_layers: int = 6 # Number of transformer layers
|
||||
num_heads: int = 8 # Number of attention heads
|
||||
dropout: float = 0.1 # Dropout rate
|
||||
use_positional_encoding: bool = False # Use absolute positional encoding
|
||||
timestep_embed_dim: int = 256 # Timestep embedding dimension
|
||||
use_rope: bool = True # Use Rotary Position Embedding
|
||||
rope_base: float = 10000.0 # RoPE base frequency
|
||||
|
||||
# Vision Encoder (CLIP)
|
||||
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
||||
use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view
|
||||
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
|
||||
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
|
||||
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
|
||||
image_crop_is_random: bool = True # Random crop during training, center at inference
|
||||
|
||||
# Text Encoder (CLIP)
|
||||
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
||||
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
|
||||
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
|
||||
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
|
||||
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Training/Optimizer
|
||||
optimizer_lr: float = 2e-5
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.0
|
||||
scheduler_name: str = "cosine"
|
||||
scheduler_warmup_steps: int = 0
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
# Auto-calculated
|
||||
drop_n_last_frames: int | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.drop_n_last_frames is None:
|
||||
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
|
||||
|
||||
self._validate()
|
||||
|
||||
def _validate(self):
|
||||
"""Validate configuration parameters."""
|
||||
# Objective validation
|
||||
if self.objective not in ["diffusion", "flow_matching"]:
|
||||
raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")
|
||||
|
||||
# Transformer validation
|
||||
if self.hidden_dim <= 0:
|
||||
raise ValueError("hidden_dim must be positive")
|
||||
if self.num_layers <= 0:
|
||||
raise ValueError("num_layers must be positive")
|
||||
if self.num_heads <= 0:
|
||||
raise ValueError("num_heads must be positive")
|
||||
if self.hidden_dim % self.num_heads != 0:
|
||||
raise ValueError("hidden_dim must be divisible by num_heads")
|
||||
if not (0.0 <= self.dropout <= 1.0):
|
||||
raise ValueError("dropout must be between 0.0 and 1.0")
|
||||
|
||||
# Vision encoder validation
|
||||
if "clip" not in self.vision_encoder_name.lower():
|
||||
raise ValueError(
|
||||
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
|
||||
)
|
||||
if (
|
||||
self.image_resize_shape
|
||||
and self.image_crop_shape
|
||||
and (
|
||||
self.image_crop_shape[0] > self.image_resize_shape[0]
|
||||
or self.image_crop_shape[1] > self.image_resize_shape[1]
|
||||
)
|
||||
):
|
||||
logging.warning(
|
||||
"image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
|
||||
self.image_crop_shape,
|
||||
self.image_resize_shape,
|
||||
)
|
||||
self.image_crop_shape = None
|
||||
|
||||
# Text encoder validation
|
||||
if "clip" not in self.text_encoder_name.lower():
|
||||
raise ValueError(
|
||||
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
|
||||
)
|
||||
|
||||
# Objective-specific validation
|
||||
if self.objective == "diffusion":
|
||||
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
|
||||
raise ValueError(
|
||||
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
|
||||
)
|
||||
if self.prediction_type not in ["epsilon", "sample"]:
|
||||
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
|
||||
if self.num_train_timesteps <= 0:
|
||||
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
|
||||
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
|
||||
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
|
||||
|
||||
elif self.objective == "flow_matching":
|
||||
if not (0.0 <= self.sigma_min <= 1.0):
|
||||
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
|
||||
if self.num_integration_steps <= 0:
|
||||
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
|
||||
if self.integration_method not in ["euler", "rk4"]:
|
||||
raise ValueError(
|
||||
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
|
||||
)
|
||||
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
|
||||
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
|
||||
if self.timestep_sampling_strategy == "beta":
|
||||
if not (0.0 < self.timestep_sampling_s <= 1.0):
|
||||
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
|
||||
if self.timestep_sampling_alpha <= 0:
|
||||
raise ValueError("timestep_sampling_alpha must be positive")
|
||||
if self.timestep_sampling_beta <= 0:
|
||||
raise ValueError("timestep_sampling_beta must be positive")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
return DiffuserSchedulerConfig(
|
||||
name=self.scheduler_name,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate that required input features are present and properly configured."""
|
||||
# If the configured crop doesn't fit, disable cropping instead of erroring.
|
||||
# Note: if image_resize_shape is set, cropping is applied *after* resizing.
|
||||
if self.image_crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
# image_ft.shape is (C, H, W)
|
||||
effective_h, effective_w = (
|
||||
self.image_resize_shape
|
||||
if self.image_resize_shape is not None
|
||||
else (image_ft.shape[1], image_ft.shape[2])
|
||||
)
|
||||
if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
|
||||
logging.warning(
|
||||
"image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
|
||||
self.image_crop_shape,
|
||||
effective_h,
|
||||
effective_w,
|
||||
key,
|
||||
)
|
||||
self.image_crop_shape = None
|
||||
break
|
||||
|
||||
if len(self.image_features) > 0:
|
||||
first_key, first_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_ft.shape:
|
||||
raise ValueError(
|
||||
f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_diffusion(self) -> bool:
|
||||
return self.objective == "diffusion"
|
||||
|
||||
@property
|
||||
def is_flow_matching(self) -> bool:
|
||||
return self.objective == "flow_matching"
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -0,0 +1,803 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Multi-Task Diffusion Transformer (DiT) Policy
|
||||
|
||||
Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
|
||||
Supports both diffusion and flow matching objectives for action generation.
|
||||
|
||||
References:
|
||||
- https://arxiv.org/abs/2507.05331
|
||||
- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/
|
||||
- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import CLIPTextModel, CLIPVisionModel
|
||||
else:
|
||||
CLIPTextModel = None
|
||||
CLIPVisionModel = None
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
|
||||
# -- Policy --
|
||||
|
||||
|
||||
class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
config_class = MultiTaskDiTConfig
|
||||
name = "multi_task_dit"
|
||||
|
||||
def __init__(self, config: MultiTaskDiTConfig, **kwargs):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self._queues = None
|
||||
|
||||
self.observation_encoder = ObservationEncoder(config)
|
||||
conditioning_dim = self.observation_encoder.conditioning_dim
|
||||
self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)
|
||||
|
||||
action_dim = config.action_feature.shape[0]
|
||||
horizon = config.horizon
|
||||
|
||||
if config.is_diffusion:
|
||||
self.objective = DiffusionObjective(
|
||||
config,
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
||||
)
|
||||
elif config.is_flow_matching:
|
||||
self.objective = FlowMatchingObjective(
|
||||
config,
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported objective: {config.objective}")
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> list:
|
||||
"""Returns parameter groups with different learning rates for vision vs non-vision parameters"""
|
||||
non_vision_params = []
|
||||
vision_encoder_params = []
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
|
||||
if "observation_encoder.vision_encoder" in name:
|
||||
vision_encoder_params.append(param)
|
||||
else:
|
||||
non_vision_params.append(param)
|
||||
|
||||
return [
|
||||
{"params": non_vision_params},
|
||||
{
|
||||
"params": vision_encoder_params,
|
||||
"lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
|
||||
},
|
||||
]
|
||||
|
||||
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
conditioning_vec = self.observation_encoder.encode(batch)
|
||||
actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)
|
||||
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.config.n_action_steps
|
||||
actions = actions[:, start:end]
|
||||
return actions
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
if self.config.image_features:
|
||||
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations"""
|
||||
self.eval()
|
||||
|
||||
for k in batch:
|
||||
if k in self._queues:
|
||||
batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
||||
|
||||
actions = self._generate_actions(batch)
|
||||
return actions
|
||||
|
||||
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Prepare batch by stacking image features if needed."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy to avoid modifying original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations"""
|
||||
if ACTION in batch:
|
||||
batch = dict(batch) # shallow copy to avoid modifying original
|
||||
batch.pop(ACTION)
|
||||
|
||||
batch = self._prepare_batch(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues[ACTION].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""Run the batch through the model and compute the loss for training"""
|
||||
batch = self._prepare_batch(batch)
|
||||
|
||||
conditioning_vec = self.observation_encoder.encode(batch)
|
||||
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
|
||||
|
||||
return loss, None
|
||||
|
||||
|
||||
# -- Observation Encoders --
|
||||
|
||||
|
||||
class CLIPVisionEncoder(nn.Module):
|
||||
"""CLIP vision encoder using the CLS token for global image representation."""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.model = CLIPVisionModel.from_pretrained(self.model_name)
|
||||
self.num_non_spatial_tokens = 1
|
||||
self.embed_dim = self.model.config.hidden_size
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Encode RGB image to CLS token."""
|
||||
outputs = self.model(pixel_values=x, output_hidden_states=False)
|
||||
cls_token = outputs.last_hidden_state[:, 0]
|
||||
b, embed_dim = cls_token.shape
|
||||
return cls_token.reshape(b, embed_dim, 1, 1)
|
||||
|
||||
def get_output_shape(self) -> tuple:
|
||||
return (self.embed_dim, 1, 1)
|
||||
|
||||
|
||||
class CLIPTextEncoder(nn.Module):
|
||||
"""CLIP text encoder with frozen weights and a learnable projection layer.
|
||||
|
||||
Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor
|
||||
pipeline to see how the tokenization is handled.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.projection_dim = projection_dim
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
|
||||
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.text_embed_dim = self.text_encoder.config.hidden_size
|
||||
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
|
||||
|
||||
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
"""Encode pre-tokenized text to feature vectors."""
|
||||
# Ensure inputs are on the same device as the model
|
||||
device = next(self.parameters()).device
|
||||
input_ids = input_ids.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
|
||||
clip_features = outputs.pooler_output
|
||||
|
||||
return self.projection(clip_features)
|
||||
|
||||
|
||||
class ObservationEncoder(nn.Module):
|
||||
"""Handles all observation processing for the conditioning vector."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._setup_preprocessing(config)
|
||||
|
||||
if config.image_features:
|
||||
self.num_cameras = len(config.image_features)
|
||||
self.camera_names = list(config.image_features.keys())
|
||||
|
||||
if config.use_separate_rgb_encoder_per_camera:
|
||||
self.vision_encoders = nn.ModuleList(
|
||||
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
|
||||
)
|
||||
self.vision_encoder = None
|
||||
else:
|
||||
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
|
||||
self.vision_encoders = None
|
||||
else:
|
||||
self.vision_encoder = None
|
||||
self.vision_encoders = None
|
||||
self.camera_names = []
|
||||
self.num_cameras = 0
|
||||
|
||||
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
|
||||
self.robot_state_dim = config.robot_state_feature.shape[0]
|
||||
else:
|
||||
self.robot_state_dim = 0
|
||||
|
||||
self.text_dim = config.hidden_dim
|
||||
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
|
||||
|
||||
self._setup_vector_output()
|
||||
|
||||
def _apply_preprocessing(self, images: Tensor) -> Tensor:
|
||||
if self.do_resize:
|
||||
images = self.resize(images)
|
||||
if self.do_crop:
|
||||
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
|
||||
return images
|
||||
|
||||
def _setup_preprocessing(self, config):
|
||||
if config.image_resize_shape is not None:
|
||||
self.do_resize = True
|
||||
self.resize = torchvision.transforms.Resize(
|
||||
size=config.image_resize_shape,
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
|
||||
antialias=True,
|
||||
)
|
||||
else:
|
||||
self.do_resize = False
|
||||
|
||||
if config.image_crop_shape is not None:
|
||||
self.do_crop = True
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
|
||||
if config.image_crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
self.do_crop = False
|
||||
|
||||
def _setup_vector_output(self):
|
||||
total_dim = 0
|
||||
|
||||
if self.vision_encoder is not None or self.vision_encoders is not None:
|
||||
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
|
||||
feature_map_shape = encoder_to_check.get_output_shape()
|
||||
c, h, w = feature_map_shape
|
||||
spatial_feature_dim = c * h * w
|
||||
total_dim += spatial_feature_dim * self.num_cameras
|
||||
|
||||
total_dim += self.robot_state_dim
|
||||
total_dim += self.text_dim
|
||||
|
||||
self.conditioning_dim = total_dim * self.config.n_obs_steps
|
||||
|
||||
def encode(self, batch: dict) -> Tensor:
|
||||
"""Encode observations to vector format."""
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
conditioning_feats = []
|
||||
|
||||
conditioning_feats.append(batch[OBS_STATE])
|
||||
|
||||
if self.vision_encoder is not None or self.vision_encoders is not None:
|
||||
images = batch[OBS_IMAGES]
|
||||
|
||||
if len(images.shape) == 5:
|
||||
images = images.unsqueeze(1)
|
||||
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
camera_features = []
|
||||
for cam_idx in range(self.num_cameras):
|
||||
cam_images = images[:, :, cam_idx]
|
||||
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
|
||||
cam_images_flat = self._apply_preprocessing(cam_images_flat)
|
||||
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
|
||||
cam_visual_features = cam_features.flatten(start_dim=1)
|
||||
cam_features_reshaped = einops.rearrange(
|
||||
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
camera_features.append(cam_features_reshaped)
|
||||
img_features = torch.cat(camera_features, dim=-1)
|
||||
conditioning_feats.append(img_features)
|
||||
else:
|
||||
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
|
||||
images_flat = self._apply_preprocessing(images_flat)
|
||||
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
|
||||
img_features = einops.rearrange(
|
||||
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
|
||||
)
|
||||
conditioning_feats.append(img_features)
|
||||
|
||||
if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
|
||||
input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
|
||||
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
|
||||
|
||||
text_features = self.text_encoder(input_ids, attention_mask)
|
||||
|
||||
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
|
||||
conditioning_feats.append(text_features)
|
||||
|
||||
combined_features = torch.cat(conditioning_feats, dim=-1)
|
||||
return combined_features.flatten(start_dim=1)
|
||||
|
||||
|
||||
# -- Transformer Components --
|
||||
|
||||
|
||||
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
|
||||
"""Modulate input with shift and scale for AdaLN-Zero."""
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
"""Sinusoidal positional embeddings for timesteps."""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class RotaryPositionalEmbedding(nn.Module):
|
||||
"""Rotary Position Embedding (RoPE) for transformers."""
|
||||
|
||||
def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
|
||||
super().__init__()
|
||||
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
|
||||
|
||||
self.head_dim = head_dim
|
||||
self.max_seq_len = max_seq_len
|
||||
self.base = base
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self._precompute_cache(max_seq_len)
|
||||
|
||||
def _precompute_cache(self, seq_len: int):
|
||||
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
||||
self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
||||
|
||||
def _rotate_half(self, x: Tensor) -> Tensor:
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
|
||||
seq_len = q.shape[2]
|
||||
if seq_len > self.max_seq_len:
|
||||
raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.")
|
||||
|
||||
cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
|
||||
sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
|
||||
|
||||
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
|
||||
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
|
||||
return q_rotated, k_rotated
|
||||
|
||||
|
||||
class RoPEAttention(nn.Module):
|
||||
"""Multi-head self-attention with Rotary Position Embedding (RoPE)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.0,
|
||||
max_seq_len: int = 512,
|
||||
rope_base: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
|
||||
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||
self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, T, _ = x.shape # noqa: N806
|
||||
|
||||
qkv = self.qkv_proj(x)
|
||||
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
|
||||
qkv = qkv.permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q, k = self.rope(q, k)
|
||||
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
|
||||
)
|
||||
|
||||
attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)
|
||||
return self.out_proj(attn_out)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
"""DiT-style transformer block with AdaLN-Zero."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 128,
|
||||
num_heads: int = 4,
|
||||
num_features: int = 128,
|
||||
dropout: float = 0.0,
|
||||
use_rope: bool = False,
|
||||
max_seq_len: int = 512,
|
||||
rope_base: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_rope = use_rope
|
||||
|
||||
if use_rope:
|
||||
self.attn = RoPEAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout,
|
||||
max_seq_len=max_seq_len,
|
||||
rope_base=rope_base,
|
||||
)
|
||||
else:
|
||||
self.multihead_attn = nn.MultiheadAttention(
|
||||
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
|
||||
)
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, hidden_size * 4),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(hidden_size * 4, hidden_size),
|
||||
)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, features: Tensor) -> Tensor:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
|
||||
features
|
||||
).chunk(6, dim=1)
|
||||
|
||||
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
|
||||
|
||||
if self.use_rope:
|
||||
attn_out = self.attn(attn_input)
|
||||
else:
|
||||
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1) * attn_out
|
||||
|
||||
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
|
||||
mlp_out = self.mlp(mlp_input)
|
||||
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DiffusionTransformer(nn.Module):
|
||||
"""Transformer-based diffusion noise prediction model."""
|
||||
|
||||
def __init__(self, config, conditioning_dim: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.conditioning_dim = conditioning_dim
|
||||
|
||||
self.action_dim = config.action_feature.shape[0]
|
||||
self.horizon = config.horizon
|
||||
self.hidden_size = config.hidden_dim
|
||||
self.num_layers = config.num_layers
|
||||
self.num_heads = config.num_heads
|
||||
self.dropout = config.dropout
|
||||
self.use_rope = config.use_rope
|
||||
|
||||
self.timestep_embed_dim = config.timestep_embed_dim
|
||||
self.time_mlp = nn.Sequential(
|
||||
SinusoidalPosEmb(self.timestep_embed_dim),
|
||||
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
self.cond_dim = self.timestep_embed_dim + conditioning_dim
|
||||
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
|
||||
|
||||
if config.use_positional_encoding:
|
||||
self.pos_embedding = nn.Parameter(
|
||||
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
|
||||
)
|
||||
else:
|
||||
self.pos_embedding = None
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TransformerBlock(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
num_features=self.cond_dim,
|
||||
dropout=self.dropout,
|
||||
use_rope=self.use_rope,
|
||||
max_seq_len=self.horizon,
|
||||
rope_base=config.rope_base,
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
for block in self.transformer_blocks:
|
||||
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
|
||||
_, seq_len, _ = x.shape
|
||||
|
||||
timestep_features = self.time_mlp(timestep)
|
||||
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1)
|
||||
|
||||
hidden_seq = self.input_proj(x)
|
||||
|
||||
if self.pos_embedding is not None:
|
||||
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :]
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_seq = block(hidden_seq, cond_features)
|
||||
|
||||
return self.output_proj(hidden_seq)
|
||||
|
||||
|
||||
# -- Objectives --
|
||||
|
||||
|
||||
class DiffusionObjective(nn.Module):
|
||||
"""Standard diffusion (DDPM/DDIM) objective implementation."""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.action_dim = action_dim
|
||||
self.horizon = horizon
|
||||
self.do_mask_loss_for_padding = do_mask_loss_for_padding
|
||||
|
||||
scheduler_kwargs = {
|
||||
"num_train_timesteps": config.num_train_timesteps,
|
||||
"beta_start": config.beta_start,
|
||||
"beta_end": config.beta_end,
|
||||
"beta_schedule": config.beta_schedule,
|
||||
"clip_sample": config.clip_sample,
|
||||
"clip_sample_range": config.clip_sample_range,
|
||||
"prediction_type": config.prediction_type,
|
||||
}
|
||||
|
||||
if config.noise_scheduler_type == "DDPM":
|
||||
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
|
||||
elif config.noise_scheduler_type == "DDIM":
|
||||
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
|
||||
|
||||
self.num_inference_steps = (
|
||||
config.num_inference_steps
|
||||
if config.num_inference_steps is not None
|
||||
else self.noise_scheduler.config.num_train_timesteps
|
||||
)
|
||||
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
clean_actions = batch[ACTION]
|
||||
noise = torch.randn_like(clean_actions)
|
||||
timesteps = torch.randint(
|
||||
low=0,
|
||||
high=self.noise_scheduler.config.num_train_timesteps,
|
||||
size=(clean_actions.shape[0],),
|
||||
device=clean_actions.device,
|
||||
).long()
|
||||
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
|
||||
|
||||
prediction_type = self.noise_scheduler.config.prediction_type
|
||||
if prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif prediction_type == "sample":
|
||||
target = clean_actions
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type: {prediction_type}")
|
||||
|
||||
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
|
||||
loss = F.mse_loss(predicted, target, reduction="none")
|
||||
|
||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||
valid_actions = ~batch["action_is_pad"]
|
||||
loss = loss * valid_actions.unsqueeze(-1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.horizon, self.action_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
||||
for t in self.noise_scheduler.timesteps:
|
||||
model_output = model(
|
||||
sample,
|
||||
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
||||
conditioning_vec=conditioning_vec,
|
||||
)
|
||||
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class FlowMatchingObjective(nn.Module):
|
||||
"""Flow matching objective: trains a model to predict velocity fields."""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.action_dim = action_dim
|
||||
self.horizon = horizon
|
||||
self.do_mask_loss_for_padding = do_mask_loss_for_padding
|
||||
|
||||
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
|
||||
if self.config.timestep_sampling_strategy == "uniform":
|
||||
return torch.rand(batch_size, device=device)
|
||||
elif self.config.timestep_sampling_strategy == "beta":
|
||||
beta_dist = torch.distributions.Beta(
|
||||
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
|
||||
)
|
||||
u = beta_dist.sample((batch_size,)).to(device)
|
||||
return self.config.timestep_sampling_s * (1.0 - u)
|
||||
else:
|
||||
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
|
||||
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
data = batch[ACTION]
|
||||
batch_size = data.shape[0]
|
||||
device = data.device
|
||||
|
||||
noise = torch.randn_like(data)
|
||||
t = self._sample_timesteps(batch_size, device)
|
||||
t_expanded = t.view(-1, 1, 1)
|
||||
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
|
||||
|
||||
target_velocity = data - (1 - self.config.sigma_min) * noise
|
||||
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
|
||||
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
|
||||
|
||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||
valid_mask = ~batch["action_is_pad"]
|
||||
loss = loss * valid_mask.unsqueeze(-1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
|
||||
|
||||
num_steps = self.config.num_integration_steps
|
||||
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
|
||||
|
||||
if self.config.integration_method == "euler":
|
||||
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
|
||||
elif self.config.integration_method == "rk4":
|
||||
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
|
||||
else:
|
||||
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
|
||||
|
||||
return x
|
||||
|
||||
def _euler_integrate(
|
||||
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
||||
) -> Tensor:
|
||||
x = x_init
|
||||
for i in range(len(time_grid) - 1):
|
||||
t_scalar = time_grid[i].item()
|
||||
dt = (time_grid[i + 1] - time_grid[i]).item()
|
||||
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
|
||||
with torch.no_grad():
|
||||
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
|
||||
x = x + dt * velocity
|
||||
return x
|
||||
|
||||
def _rk4_integrate(
|
||||
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
||||
) -> Tensor:
|
||||
x = x_init
|
||||
|
||||
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
|
||||
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
|
||||
with torch.no_grad():
|
||||
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
|
||||
|
||||
for i in range(len(time_grid) - 1):
|
||||
t = time_grid[i].item()
|
||||
dt = (time_grid[i + 1] - time_grid[i]).item()
|
||||
|
||||
k1 = dynamics(x, t)
|
||||
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
|
||||
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
|
||||
k4 = dynamics(x + dt * k3, t + dt)
|
||||
|
||||
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
|
||||
|
||||
return x
|
||||
@@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_multi_task_dit_pre_post_processors(
|
||||
config: MultiTaskDiTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy.
|
||||
|
||||
The pre-processing pipeline prepares the input data for the model by:
|
||||
1. Renaming features.
|
||||
2. Adding a batch dimension.
|
||||
3. Tokenizing the language task description (if present).
|
||||
4. Moving the data to the specified device.
|
||||
5. Normalizing the input and output features based on dataset statistics.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Unnormalizing the output features to their original scale.
|
||||
2. Moving the data to the CPU.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the Multi-Task DiT policy,
|
||||
containing feature definitions, normalization mappings, and device information.
|
||||
dataset_stats: A dictionary of statistics used for normalization.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.text_encoder_name,
|
||||
padding=config.tokenizer_padding,
|
||||
padding_side=config.tokenizer_padding_side,
|
||||
max_length=config.tokenizer_max_length,
|
||||
truncation=config.tokenizer_truncation,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
device=config.device,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: E402
|
||||
|
||||
"""Test script for Multi-Task DiT policy.
|
||||
|
||||
To run tests locally:
|
||||
python -m pytest tests/policies/multi_task_dit/test_multi_task_dit.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local transformers installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
|
||||
make_multi_task_dit_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed():
|
||||
seed = 17
|
||||
set_seed(seed)
|
||||
|
||||
|
||||
def create_train_batch(
|
||||
batch_size: int = 2,
|
||||
n_obs_steps: int = 2,
|
||||
horizon: int = 16,
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 10,
|
||||
height: int = 224,
|
||||
width: int = 224,
|
||||
) -> dict[str, Tensor]:
|
||||
"""Create a training batch with visual input and text."""
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, n_obs_steps, state_dim),
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width),
|
||||
ACTION: torch.randn(batch_size, horizon, action_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch(
|
||||
batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224
|
||||
) -> dict:
|
||||
"""Create observation batch for inference for a single timestep."""
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width),
|
||||
"task": ["pick up the red cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def create_config(
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 10,
|
||||
n_obs_steps: int = 2,
|
||||
horizon: int = 16,
|
||||
n_action_steps: int = 8,
|
||||
with_visual: bool = True,
|
||||
height: int = 224,
|
||||
width: int = 224,
|
||||
) -> MultiTaskDiTConfig:
|
||||
"""Create a MultiTaskDiT config for testing.
|
||||
|
||||
Args:
|
||||
state_dim: Dimension of state observations
|
||||
action_dim: Dimension of actions
|
||||
n_obs_steps: Number of observation steps
|
||||
horizon: Action prediction horizon
|
||||
n_action_steps: Number of action steps to execute
|
||||
with_visual: Whether to include visual input (default: True)
|
||||
height: Image height (only used if with_visual=True)
|
||||
width: Image width (only used if with_visual=True)
|
||||
"""
|
||||
input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}
|
||||
|
||||
if with_visual:
|
||||
input_features[f"{OBS_IMAGES}.laptop"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(3, height, width)
|
||||
)
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
# Use smaller model for faster tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
|
||||
def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int):
|
||||
"""Test forward pass (training mode)."""
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
assert loss.shape == ()
|
||||
|
||||
# Test backward pass
|
||||
loss.backward()
|
||||
|
||||
|
||||
def test_multi_task_dit_pre_post_processors():
|
||||
"""Test pre and post processors for Multi-Task DiT policy."""
|
||||
state_dim = 10
|
||||
action_dim = 8
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=8,
|
||||
)
|
||||
config.device = "cpu"
|
||||
|
||||
# Set normalization mode to match the stats we're providing
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# Create dataset stats for normalization
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(state_dim),
|
||||
"std": torch.ones(state_dim),
|
||||
},
|
||||
"action": {
|
||||
"min": torch.full((action_dim,), -1.0),
|
||||
"max": torch.ones(action_dim),
|
||||
},
|
||||
}
|
||||
|
||||
# Create processors
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
|
||||
config=config, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
# Test preprocessor with sample data
|
||||
batch = {
|
||||
"observation.state": torch.randn(state_dim),
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
|
||||
ACTION: torch.randn(action_dim),
|
||||
"task": "pick up the cube",
|
||||
}
|
||||
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Check that data is batched
|
||||
assert processed_batch["observation.state"].shape == (1, state_dim)
|
||||
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
|
||||
assert processed_batch[ACTION].shape == (1, action_dim)
|
||||
# Check that task text was tokenized
|
||||
assert OBS_LANGUAGE_TOKENS in processed_batch
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch
|
||||
assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension
|
||||
assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension
|
||||
|
||||
# Check that data is on correct device
|
||||
assert processed_batch["observation.state"].device.type == "cpu"
|
||||
assert processed_batch[f"{OBS_IMAGES}.laptop"].device.type == "cpu"
|
||||
assert processed_batch[ACTION].device.type == "cpu"
|
||||
|
||||
# Test postprocessor with sample action (PolicyAction is just a torch.Tensor)
|
||||
action = torch.randn(1, action_dim)
|
||||
processed_action = postprocessor(action)
|
||||
|
||||
# Check that action is unnormalized and on CPU
|
||||
assert processed_action.shape == (1, action_dim)
|
||||
assert processed_action.device.type == "cpu"
|
||||
|
||||
|
||||
def test_multi_task_dit_pre_post_processors_normalization():
|
||||
"""Test that normalization and unnormalization work correctly with simple sanity check numbers."""
|
||||
state_dim = 3
|
||||
action_dim = 2
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=2,
|
||||
horizon=16,
|
||||
n_action_steps=8,
|
||||
)
|
||||
config.device = "cpu"
|
||||
|
||||
# Set normalization mode to match the stats we're providing
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# Use simple stats that will actually transform the values
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.full((state_dim,), 5.0),
|
||||
"std": torch.full((state_dim,), 2.0),
|
||||
},
|
||||
"action": {
|
||||
"min": torch.zeros(action_dim),
|
||||
"max": torch.full((action_dim,), 2.0),
|
||||
},
|
||||
}
|
||||
|
||||
# Create processors
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
|
||||
config=config, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
# Use simple input values
|
||||
input_state = torch.tensor([7.0, 5.0, 3.0]) # Will normalize to [1.0, 0.0, -1.0]
|
||||
input_action = torch.tensor([1.0, 2.0]) # Will normalize to [0.0, 1.0]
|
||||
|
||||
batch = {
|
||||
"observation.state": input_state,
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
|
||||
ACTION: input_action,
|
||||
"task": "test task",
|
||||
}
|
||||
|
||||
# Process through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# State normalization: (x - mean) / std
|
||||
expected_normalized_state = torch.tensor([1.0, 0.0, -1.0])
|
||||
assert torch.allclose(processed_batch["observation.state"][0], expected_normalized_state, atol=1e-5)
|
||||
|
||||
# Action normalization: (x - min) / (max - min) * 2 - 1
|
||||
expected_normalized_action = torch.tensor([0.0, 1.0])
|
||||
assert torch.allclose(processed_batch[ACTION][0], expected_normalized_action, atol=1e-5)
|
||||
|
||||
# Test unnormalization: should recover original values
|
||||
normalized_action_tensor = processed_batch[ACTION][0:1] # Keep batch dimension
|
||||
unnormalized_action = postprocessor(normalized_action_tensor)
|
||||
|
||||
# Should recover original action values
|
||||
assert torch.allclose(unnormalized_action[0], input_action, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
|
||||
def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||
"""Test select_action (inference mode)."""
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.reset() # Reset queues before inference
|
||||
|
||||
# Create processors - use IDENTITY normalization when no stats provided
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
selected_action = policy.select_action(processed_obs)
|
||||
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
|
||||
processed_action = postprocessor(selected_action)
|
||||
assert processed_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_diffusion_objective():
|
||||
"""Test policy with diffusion objective."""
|
||||
batch_size = 2
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
# Use diffusion objective
|
||||
objective="diffusion",
|
||||
noise_scheduler_type="DDPM",
|
||||
num_train_timesteps=100,
|
||||
num_inference_steps=10,
|
||||
# Smaller model for tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
# Test inference
|
||||
policy.eval()
|
||||
# Use IDENTITY normalization when no stats provided
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
selected_action = policy.select_action(processed_obs)
|
||||
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
|
||||
processed_action = postprocessor(selected_action)
|
||||
assert processed_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_flow_matching_objective():
|
||||
"""Test policy with flow matching objective."""
|
||||
batch_size = 2
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
# Use flow matching objective
|
||||
objective="flow_matching",
|
||||
sigma_min=0.0,
|
||||
num_integration_steps=10, # Fewer steps for faster tests
|
||||
integration_method="euler",
|
||||
# Smaller model for tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Process batch through preprocessor to tokenize task text
|
||||
processed_batch = preprocessor(batch)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
# Test inference
|
||||
policy.eval()
|
||||
# Use IDENTITY normalization when no stats provided
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
selected_action = policy.select_action(processed_obs)
|
||||
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
|
||||
processed_action = postprocessor(selected_action)
|
||||
assert processed_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_save_and_load(tmp_path):
|
||||
"""Test that the policy can be saved and loaded correctly."""
|
||||
root = tmp_path / "test_multi_task_dit_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
batch_size = 2
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
# Get device before saving
|
||||
device = next(policy.parameters()).device
|
||||
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
|
||||
|
||||
# Explicitly move loaded_policy to the same device
|
||||
loaded_policy.to(device)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Process batch through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
# Move batch to the same device as the policy
|
||||
for key in processed_batch:
|
||||
if isinstance(processed_batch[key], torch.Tensor):
|
||||
processed_batch[key] = processed_batch[key].to(device)
|
||||
# Collect policy values before saving
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Process observation through preprocessor
|
||||
processed_obs = preprocessor(observation_batch)
|
||||
actions = policy.select_action(processed_obs)
|
||||
|
||||
with seeded_context(12):
|
||||
# Process batch through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
# Collect policy values after loading
|
||||
loaded_loss, _ = loaded_policy.forward(processed_batch)
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
processed_obs = preprocessor(loaded_observation_batch)
|
||||
loaded_actions = loaded_policy.select_action(processed_obs)
|
||||
|
||||
# Compare state dicts
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
assert torch.allclose(loss, loaded_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_get_optim_params():
|
||||
"""Test that the policy returns correct optimizer parameter groups."""
|
||||
config = create_config(
|
||||
state_dim=10,
|
||||
action_dim=10,
|
||||
n_obs_steps=2,
|
||||
horizon=16,
|
||||
n_action_steps=8,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
param_groups = policy.get_optim_params()
|
||||
|
||||
# Should have 2 parameter groups: non-vision and vision encoder
|
||||
assert len(param_groups) == 2
|
||||
|
||||
# First group is non-vision params (no lr specified, will use default)
|
||||
assert "params" in param_groups[0]
|
||||
assert len(param_groups[0]["params"]) > 0
|
||||
|
||||
# Second group is vision encoder params with different lr
|
||||
assert "params" in param_groups[1]
|
||||
assert "lr" in param_groups[1]
|
||||
expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
|
||||
assert param_groups[1]["lr"] == expected_lr
|
||||
Reference in New Issue
Block a user