make default params more aligned with paper and pretrained models

- adding possibility of freezing qwen backbone and world model
- added tests for weight loading
This commit is contained in:
Maximellerbach
2026-05-18 12:55:55 +02:00
parent cee80daa88
commit e9c171ead0
4 changed files with 128 additions and 73 deletions
+19 -29
View File
@@ -82,26 +82,15 @@ policy = VLAJEPAPolicy.from_pretrained("lerobot/VLA-JEPA-LIBERO")
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| -------------------------------------------- | ---------------------------------- | ------------------------------------------------------------------- |
| `qwen_model_name` | `"Qwen/Qwen3-VL-2B-Instruct"` | Qwen3-VL backbone variant |
| `jepa_encoder_name` | `"facebook/vjepa2-vitl-fpc64-256"` | V-JEPA2 video encoder |
| `chunk_size` | 16 | Number of actions predicted per inference call |
| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 16 | Video clip length fed to the world model |
| `jepa_tubelet_size` | 2 | Temporal patch size of the video encoder (must match encoder) |
| `action_model_type` | `"DiT-B"` | DiT preset — controls hidden dim, heads, head dim |
| `action_hidden_size` | 1024 | DiT output projection size (and action decoder input size) |
| `action_num_layers` | 12 | Number of DiT transformer blocks |
| `num_target_vision_tokens` | 32 | Learned future-vision query tokens prepended to the action sequence |
| `action_max_seq_len` | 1024 | Max length of the positional embedding table in the action head |
| `num_action_tokens_per_timestep` | 4 | Special action tokens per temporal step (used for WM conditioning) |
| `num_embodied_action_tokens_per_instruction` | 8 | Instruction-level embodied tokens appended to the Qwen sequence |
| `num_inference_timesteps` | 10 | Euler integration steps for action denoising |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `predictor_depth` | 6 | Number of transformer blocks in the video predictor |
| `repeated_diffusion_steps` | 4 | Independent noise draws per batch item (CogACT-style augmentation) |
| Parameter | Default | Description |
| ------------------------- | ------- | -------------------------------------------------------------- |
| `chunk_size` | 16 | Number of actions predicted per inference call |
| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 16 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 10 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
---
@@ -120,10 +109,19 @@ lerobot-train \
```bash
lerobot-train \
policy.path=lerobot/VLA-JEPA-LIBERO \
policy.path=lerobot/VLA-JEPA-Pretrain \
dataset.repo_id=your_org/your_dataset
```
If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
policy.path=lerobot/VLA-JEPA-Pretrain \
dataset.repo_id=your_org/your_dataset \
policy.freeze_qwen=true
```
### Reproducing the LIBERO results
**Training on LIBERO:**
@@ -138,14 +136,6 @@ TODO(Maxime):
lerobot-train \
policy.path=lerobot/VLA-JEPA-Pretrain \
dataset.repo_id=lerobot/libero_10 \
policy.chunk_size=7 \
policy.n_action_steps=7 \
policy.future_action_window_size=6 \
policy.num_video_frames=8 \
policy.num_action_tokens_per_timestep=8 \
policy.num_embodied_action_tokens_per_instruction=32 \
policy.action_num_layers=16 \
policy.predictor_depth=12 \
training.num_steps=50000 \
env.type=libero \
env.task=libero_10