lerobot/docs/source/multi_task_dit.mdx
Bryson Jones 2e069b1c47 Feature/add multitask diffusion transformer policy implementation (#2545)
Signed-off-by: Bryson Jones <63133702+brysonjones@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Yufei Sun <skieyfly@gmail.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2026-03-28 00:41:26 +01:00


# Multitask DiT Policy
Multitask Diffusion Transformer (DiT) Policy evolves the original Diffusion Policy architecture by using a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.
## Model Overview
The model uses:
- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation
This model is exciting because, with only ~450M parameters and significantly less training, it can achieve dexterity competitive with multi-billion-parameter VLAs.
## Installation Requirements
Multitask DiT Policy has additional dependencies. Install them with:
```bash
pip install "lerobot[multi_task_dit]"
```
This installs all necessary dependencies, including the Hugging Face Transformers library for CLIP models.
## Usage
To use Multitask DiT in your LeRobot configuration, specify the policy type as:
```bash
--policy.type=multi_task_dit
```
## Training
### Basic Training Command
Here's a complete command for training Multitask DiT on your dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/multitask_dit_training \
--batch_size=32 \
--steps=5000 \
--save_freq=500 \
--log_freq=100 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
For reliable performance, start with these suggested default hyperparameters:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/multitask_dit_training \
--batch_size=320 \
--steps=30000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_train_timesteps=100 \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
**Key Parameters:**
- **Batch Size**: 192-320 - gives the best training dynamics, provided your GPU has enough memory for it
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
- **Training Steps**: >30k steps recommended for a single task
### Training Configuration Parameters
#### Objective Selection
Choose between diffusion and flow matching:
```bash
# Flag fragments: append the ones you need to your lerobot-train command
# Diffusion objective (default)
--policy.objective=diffusion
--policy.noise_scheduler_type=DDPM          # or "DDIM"
--policy.num_train_timesteps=100
--policy.num_inference_steps=10             # fewer steps for faster inference
--policy.beta_schedule=squaredcos_cap_v2    # noise schedule type
--policy.prediction_type=epsilon            # "epsilon" (predict noise) or "sample" (predict clean sample)
--policy.clip_sample=true                   # clip samples during denoising
--policy.clip_sample_range=1.0              # clipping range [-x, x]
# Flow matching objective
--policy.objective=flow_matching
--policy.timestep_sampling_strategy=beta    # or "uniform"; beta sampling appears to perform much better in practice
--policy.num_integration_steps=100
--policy.integration_method=euler           # or "rk4"
--policy.sigma_min=0.0                      # minimum noise in the flow interpolation path
```
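For intuition about `num_integration_steps`, `integration_method=euler`, and `sigma_min`, the following is a minimal pure-Python sketch of Euler integration along a linear flow-matching path. It assumes the common convention in which `t=0` is pure noise and `t=1` is the clean action, and it substitutes a closed-form velocity for the trained network. The function names are hypothetical and are not taken from the LeRobot code, which may differ in its conventions.

```python
def interpolate(noise, action, t, sigma_min=0.0):
    """Point on the linear path from a noise sample (t=0) to a clean action (t=1)."""
    return [(1.0 - (1.0 - sigma_min) * t) * n + t * a for n, a in zip(noise, action)]


def target_velocity(noise, action, sigma_min=0.0):
    """Time derivative of `interpolate`; the regression target in this sketch."""
    return [a - (1.0 - sigma_min) * n for n, a in zip(noise, action)]


def euler_integrate(velocity_fn, start, num_steps):
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with fixed Euler steps."""
    x = list(start)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        v = velocity_fn(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


noise = [0.5, -1.2, 0.3]
action = [0.1, 0.4, -0.7]


def oracle_velocity(x, t):
    # Stand-in for the trained model: the exact velocity of this single path.
    return target_velocity(noise, action)


generated = euler_integrate(oracle_velocity, noise, num_steps=100)
```

Because the oracle velocity is constant along the path, Euler integration recovers `action` up to floating-point error. With a learned velocity field the result depends on the step count, which is why `num_integration_steps` trades accuracy against inference time.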
#### Transformer Architecture
Adjust model capacity based on dataset size:
```bash
# Small datasets (< 100 examples)
--policy.num_layers=4 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Medium datasets (100-5k examples) - default
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Large datasets (> 5k examples)
--policy.num_layers=8 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
```
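The sizing rule above can be captured in a small helper. This is only a sketch: `suggest_transformer_config` and `to_cli_flags` are hypothetical names, and the thresholds simply restate the dataset-size bands listed in this section.

```python
def suggest_transformer_config(num_examples, hidden_dim=512, head_dim=64):
    """Pick a layer count from dataset size and derive num_heads = hidden_dim // head_dim."""
    if hidden_dim % head_dim != 0:
        raise ValueError("hidden_dim should be a multiple of head_dim")
    if num_examples < 100:
        num_layers = 4  # small datasets
    elif num_examples <= 5000:
        num_layers = 6  # medium datasets (default)
    else:
        num_layers = 8  # large datasets
    return {
        "num_layers": num_layers,
        "hidden_dim": hidden_dim,
        "num_heads": hidden_dim // head_dim,
    }


def to_cli_flags(config):
    """Render a config dict as --policy.<key>=<value> flags."""
    return [f"--policy.{key}={value}" for key, value in config.items()]


flags = to_cli_flags(suggest_transformer_config(num_examples=250))
```

For a 250-example dataset this yields the medium (default) configuration.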
**Positional Encoding Options:**
The model supports two positional encoding methods for action sequences:
```bash
# Rotary Position Embedding (RoPE) - default, recommended
--policy.use_rope=true \
--policy.rope_base=10000.0 # Base frequency for RoPE
# Absolute positional encoding
--policy.use_positional_encoding=true # Disables RoPE when true
```
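To illustrate what `rope_base` controls, here is a minimal pure-Python sketch of rotary position embedding applied to a single vector. It uses the interleaved-pair formulation, which is an assumption: the policy's own module may lay out the pairs differently. `rope_rotate` is a hypothetical name.

```python
import math


def rope_rotate(vec, position, base=10000.0):
    """Rotate consecutive (even, odd) pairs of `vec` by position-dependent angles."""
    dim = len(vec)
    if dim % 2 != 0:
        raise ValueError("RoPE needs an even number of dimensions")
    out = []
    for i in range(0, dim, 2):
        # Pair i // 2 rotates at frequency base^(-i / dim); later pairs rotate more slowly.
        theta = position * base ** (-i / dim)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -0.5, 0.8, 0.1]
k = [0.7, 0.2, -0.4, 0.9]

# Both pairs of positions are 3 steps apart, so the attention scores match.
score_near = dot(rope_rotate(q, 5), rope_rotate(k, 2))
score_far = dot(rope_rotate(q, 13), rope_rotate(k, 10))
```

The two scores agree because the rotated dot product depends only on the relative offset between query and key positions, which is the property that makes RoPE attractive for action sequences.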
**Other Transformer Parameters:**
```bash
--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0)
--policy.timestep_embed_dim=256 # Timestep embedding dimension
```
#### Vision Encoder Configuration
```bash
# Use different CLIP model for more expressivity at the cost of inference time
# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset
--policy.vision_encoder_name=openai/clip-vit-large-patch14
# Use separate vision encoder per camera
# This may be useful when cameras have significantly different characteristics, but
# be wary of increased VRAM footprint.
--policy.use_separate_rgb_encoder_per_camera=true
# Image preprocessing
# Resizing your images first may be needed for inference speed-ups
--policy.image_resize_shape=[XXX,YYY] \
--policy.image_crop_shape=[224,224] \
--policy.image_crop_is_random=true # Random during training, center at inference
```
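The difference between random and center cropping comes down to how the crop origin is chosen. The sketch below assumes shapes are given as `(height, width)` and that the crop fits inside the resized image; `crop_origin` is a hypothetical helper, not a LeRobot function.

```python
import random


def crop_origin(image_hw, crop_hw, is_random, rng=None):
    """Return the (top, left) corner of a crop window inside an image."""
    height, width = image_hw
    crop_h, crop_w = crop_hw
    if crop_h > height or crop_w > width:
        raise ValueError("crop must fit inside the image")
    if is_random:
        # Training: sample any valid origin to augment the data.
        rng = rng or random.Random()
        top = rng.randint(0, height - crop_h)
        left = rng.randint(0, width - crop_w)
    else:
        # Inference: take the deterministic center crop.
        top = (height - crop_h) // 2
        left = (width - crop_w) // 2
    return top, left


center = crop_origin((240, 320), (224, 224), is_random=False)
jittered = crop_origin((240, 320), (224, 224), is_random=True, rng=random.Random(0))
```

For a 240x320 image and a 224x224 crop, the center origin is `(8, 48)`, while the random origin falls anywhere in `[0, 16] x [0, 96]`.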
#### Text Encoder Configuration
```bash
# Use different CLIP text encoder model
# same as vision: experiment with larger or smaller models depending on the
# complexity of your tasks and size of dataset
--policy.text_encoder_name=openai/clip-vit-large-patch14
```
#### Learning Rate Configuration
The vision encoder uses a separate learning rate multiplier; 1/10th of the base learning rate is the suggested starting point:
```bash
--policy.optimizer_lr=2e-5 \
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
```
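As a rough sketch of how such a multiplier is typically applied, the helper below splits parameter names into two groups with different learning rates. The parameter names and the `vision_encoder.` prefix are hypothetical, chosen only for illustration; the list-of-dicts layout mirrors the per-parameter-group format that `torch.optim` optimizers accept.

```python
def build_param_groups(param_names, base_lr, vision_lr_multiplier,
                       vision_prefix="vision_encoder."):
    """Split parameters into vision-encoder and other groups with separate LRs."""
    vision, other = [], []
    for name in param_names:
        (vision if name.startswith(vision_prefix) else other).append(name)
    return [
        {"name": "vision_encoder", "params": vision, "lr": base_lr * vision_lr_multiplier},
        {"name": "other", "params": other, "lr": base_lr},
    ]


# Hypothetical parameter names, for illustration only.
names = [
    "vision_encoder.embeddings.weight",
    "vision_encoder.layers.0.attn.weight",
    "transformer.blocks.0.mlp.weight",
    "text_projection.weight",
]
groups = build_param_groups(names, base_lr=2e-5, vision_lr_multiplier=0.1)
```

With `optimizer_lr=2e-5` and a multiplier of `0.1`, the vision encoder group trains at `2e-6` while everything else trains at `2e-5`.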
### Training Tuning Guidelines
#### 1. Flow Matching with Beta Sampling
The diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331).
We have additionally implemented a flow-matching objective, which is described at a high level in the [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).
Consider testing the flow-matching objective and evaluating performance differences for your task:
```bash
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.timestep_sampling_alpha=1.5 \
--policy.timestep_sampling_beta=1.0 \
--policy.timestep_sampling_s=0.999
```
This hasn't been shown to be a silver bullet across every use case, but it occasionally results in smoother and more consistent actions.
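To build intuition for how the three sampling parameters shape the timestep distribution, here is a sketch of one plausible reading: draw `u` from a Beta(`alpha`, `beta`) distribution and map it to `t = s * (1 - u)`. This mapping is an assumption and is not confirmed by this document; check the policy's modeling code for the exact formula. `sample_timestep` is a hypothetical name.

```python
import random


def sample_timestep(alpha=1.5, beta=1.0, s=0.999, rng=random):
    """Draw a timestep in [0, s] from a shifted, flipped beta distribution (assumed form)."""
    u = rng.betavariate(alpha, beta)
    return s * (1.0 - u)


rng = random.Random(0)
samples = [sample_timestep(rng=rng) for _ in range(10_000)]
mean_t = sum(samples) / len(samples)
```

Under this reading, Beta(1.5, 1.0) has mean 0.6, so the sampled timesteps average about 0.4 and are skewed toward small `t`. A `uniform` strategy would instead average about 0.5.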
#### 2. Number of Transformer Layers
Match model capacity to your dataset size:
- **Small datasets** (< 100 examples): Reduce to 4 layers
- **Large datasets** (> 5k examples): Increase to 8 layers
#### 3. `horizon` Tuning
The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency:
- **30 Hz frequency**: `horizon=30`
- **10 Hz frequency**: `horizon=10`
Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.
#### 4. `n_action_steps` Sensitivity
The model can also be very sensitive to `n_action_steps`. Start with it being around 0.8 seconds based on your control frequency and tune from there:
- **Lower values**: More reactive but potentially less stable for long-horizon tasks
- **Higher values**: Better for long-horizon execution, but the robot runs open-loop for longer, so it has less opportunity to recover from a failure
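The starting points for `horizon` and `n_action_steps` are both durations converted to step counts at your control frequency. A small sketch of that arithmetic, using hypothetical helper names:

```python
def steps_for_duration(seconds, control_hz):
    """Convert a duration in seconds to a whole number of control steps (at least 1)."""
    return max(1, round(seconds * control_hz))


def suggest_chunking(control_hz, horizon_s=1.0, action_s=0.8):
    """Starting values: ~1.0 s prediction horizon, ~0.8 s executed per inference call."""
    horizon = steps_for_duration(horizon_s, control_hz)
    # Never execute more steps than were predicted.
    n_action_steps = min(horizon, steps_for_duration(action_s, control_hz))
    return {"horizon": horizon, "n_action_steps": n_action_steps}


at_30hz = suggest_chunking(30)
at_10hz = suggest_chunking(10)
```

At 30 Hz this gives `horizon=30` and `n_action_steps=24`; at 10 Hz it gives `horizon=10` and `n_action_steps=8`. The recommended command earlier in this guide uses `horizon=32`, which is close to the same 1-second target.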
### Inference Tuning
For faster inference, use DDIM with fewer sampling steps:
```bash
--policy.noise_scheduler_type=DDIM \
--policy.num_inference_steps=10
```
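The speed-up comes from denoising on a subset of the training timesteps. The sketch below shows one evenly strided selection, similar to the "leading" spacing offered by common scheduler implementations; the spacing actually used depends on the scheduler configuration, and `inference_timesteps` is a hypothetical name.

```python
def inference_timesteps(num_train_timesteps, num_inference_steps):
    """Pick an evenly strided, descending subset of the training timesteps."""
    if not 0 < num_inference_steps <= num_train_timesteps:
        raise ValueError("num_inference_steps must be in [1, num_train_timesteps]")
    stride = num_train_timesteps // num_inference_steps
    return [i * stride for i in range(num_inference_steps)][::-1]


timesteps = inference_timesteps(num_train_timesteps=100, num_inference_steps=10)
```

With 100 training timesteps and 10 inference steps, the sampler visits timesteps 90, 80, ..., 0, so each action chunk needs 10 model evaluations instead of 100.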
### Resuming Training
To resume training from a checkpoint:
```bash
lerobot-train \
--config_path=./outputs/multitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```
The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.
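Before resuming, it can help to confirm that the checkpoint directory is complete. This is a minimal sketch using only the file names mentioned in this section; `missing_checkpoint_files` is a hypothetical helper, not part of the LeRobot CLI.

```python
from pathlib import Path

# File names taken from the resume instructions above.
REQUIRED_FILES = ("model.safetensors", "config.json", "train_config.json")


def missing_checkpoint_files(pretrained_model_dir):
    """Return the required checkpoint files that are absent from the directory."""
    root = Path(pretrained_model_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


checkpoint_dir = "./outputs/multitask_dit_training/checkpoints/last/pretrained_model"
missing = missing_checkpoint_files(checkpoint_dir)
if missing:
    print(f"Cannot resume from {checkpoint_dir}; missing: {', '.join(missing)}")
```

An empty list means all three files are present and the resume command should be able to load the configuration.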
## Common Failure Modes and Debugging
Training these models can be finicky. Here are common failure modes and debugging approaches:
### Idling / No Motion
The model may "collapse" during inference, resulting in static or no motion. This can occur when:
1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex.
2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.
**Debugging tips:**
- Increase dataset size (double until you get to over 300 examples)
- Train for longer, up to 100k steps, even when the loss flatlines
- Check if the model is receiving proper language instructions or increase diversity of instruction
### Executing the Wrong Task
Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.
**Potential causes:**
- Language instruction ambiguity
- Insufficient task-specific training data
- Model confusion between similar tasks in the multitask dataset
**Debugging tips:**
- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
- Check task distribution in your training dataset and add weighting to the failing/ignored task
- Consider task-specific fine-tuning
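One way to check the task distribution and derive weights is to count episodes per instruction and give each task equal total sampling mass. The sketch below assumes you have one instruction string per episode. The task strings are made up, and how the weights are fed into training (for example, through a weighted sampler) is left open, since this guide does not describe a built-in option for it.

```python
from collections import Counter


def task_sampling_weights(episode_tasks):
    """Per-episode weights such that every task receives equal total probability."""
    counts = Counter(episode_tasks)
    num_tasks = len(counts)
    # Each task gets 1 / num_tasks of the mass, split evenly across its episodes.
    return [1.0 / (num_tasks * counts[task]) for task in episode_tasks]


# Hypothetical dataset: one instruction string per episode.
episode_tasks = ["pick up the cube"] * 8 + ["pick up the ball"] * 2
task_counts = Counter(episode_tasks)
weights = task_sampling_weights(episode_tasks)
```

Here the under-represented "ball" episodes each receive weight 0.25 while the "cube" episodes each receive 0.0625, so both tasks are sampled equally often overall.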
### Training Instability
If training loss is unstable or diverging:
- Try adjusting learning rate between `1e-5` and `3e-4`
- Increase batch size if possible
- Check that your dataset normalization is correct
- Verify image preprocessing is working correctly
## Performance Considerations
### GPU Requirements
- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable inference speed
- **Training**: A GPU with enough VRAM for batch sizes above 64 is ideal; the VRAM required varies with factors such as the number of image observations
### Batch Size Recommendations
- **Minimum**: 64 (less than this may result in unstable training)
- **Recommended**: 256-320 (best performance, requires larger GPU)
## Example: Training on Custom Dataset
Here's a complete example training on a custom dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/multitask_dit_training \
--batch_size=320 \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--eval_freq=1000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
--policy.image_resize_shape=[320,240] \
--policy.image_crop_shape=[224,224] \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true \
--wandb.project=multitask_dit
```
## References
For more details on the technical implementation and architecture, see:
- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)