mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
docs: update X-VLA training strategies/commands (#2611)
This commit is contained in:
+40
-82
@@ -24,7 +24,7 @@ Built from pure Transformer encoders, X-VLA scales naturally with model size and
|
|||||||
<img
|
<img
|
||||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
|
||||||
alt="XVLA Architecture 2"
|
alt="XVLA Architecture 2"
|
||||||
style="width: 32%; max-width: 450px; height: auto;"
|
style="width: 60%; height: auto;"
|
||||||
/>
|
/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
@@ -120,7 +120,7 @@ Adapted for Google Robot platforms.
|
|||||||
|
|
||||||
### Recommended Training Configuration
|
### Recommended Training Configuration
|
||||||
|
|
||||||
When fine-tuning X-VLA for a new embodiment or task, we recommend the following freezing strategy:
|
When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM, and also setting the `policy.dtype=bfloat16` to not hit OOM errors.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
@@ -129,25 +129,26 @@ lerobot-train \
|
|||||||
--job_name=xvla_training \
|
--job_name=xvla_training \
|
||||||
--policy.path="lerobot/xvla-base" \
|
--policy.path="lerobot/xvla-base" \
|
||||||
--policy.repo_id="HF_USER/xvla-your-robot" \
|
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||||
|
--policy.dtype=bfloat16 \
|
||||||
--steps=3000 \
|
--steps=3000 \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--policy.freeze_vision_encoder=True \
|
--policy.freeze_vision_encoder=false \
|
||||||
--policy.freeze_language_encoder=True \
|
--policy.freeze_language_encoder=false \
|
||||||
--policy.train_policy_transformer=True \
|
--policy.train_policy_transformer=true \
|
||||||
--policy.train_soft_prompts=True \
|
--policy.train_soft_prompts=true \
|
||||||
--policy.action_mode=YOUR_ACTION_MODE
|
--policy.action_mode=YOUR_ACTION_MODE
|
||||||
```
|
```
|
||||||
|
|
||||||
### Training Parameters Explained
|
### Training Parameters Explained
|
||||||
|
|
||||||
| Parameter | Default | Description |
|
| Parameter | Default | Description |
|
||||||
| -------------------------- | ------- | ---------------------------------------- |
|
| -------------------------- | ------- | ---------------------------------------------- |
|
||||||
| `freeze_vision_encoder` | `True` | Freeze the VLM vision encoder weights |
|
| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights |
|
||||||
| `freeze_language_encoder` | `True` | Freeze the VLM language encoder weights |
|
| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights |
|
||||||
| `train_policy_transformer` | `True` | Allow policy transformer layers to train |
|
| `train_policy_transformer` | `true` | Allow policy transformer layers to train |
|
||||||
| `train_soft_prompts` | `True` | Allow soft prompts to train |
|
| `train_soft_prompts` | `true` | Allow soft prompts to train |
|
||||||
|
|
||||||
**💡 Best Practice**: For Phase II adaptation to new embodiments, freeze the VLM encoders and only train the policy transformer and soft prompts. This provides excellent sample efficiency with minimal compute.
|
**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts.
|
||||||
|
|
||||||
### Example: Training on Bimanual Robot
|
### Example: Training on Bimanual Robot
|
||||||
|
|
||||||
@@ -157,14 +158,15 @@ lerobot-train \
|
|||||||
--output_dir=./outputs/xvla_bimanual \
|
--output_dir=./outputs/xvla_bimanual \
|
||||||
--job_name=xvla_so101_training \
|
--job_name=xvla_so101_training \
|
||||||
--policy.path="lerobot/xvla-base" \
|
--policy.path="lerobot/xvla-base" \
|
||||||
|
--policy.dtype=bfloat16 \
|
||||||
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
|
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
|
||||||
--steps=3000 \
|
--steps=3000 \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--policy.action_mode=so101_bimanual \
|
--policy.action_mode=so101_bimanual \
|
||||||
--policy.freeze_vision_encoder=True \
|
--policy.freeze_vision_encoder=false \
|
||||||
--policy.freeze_language_encoder=True \
|
--policy.freeze_language_encoder=false \
|
||||||
--policy.train_policy_transformer=True \
|
--policy.train_policy_transformer=true \
|
||||||
--policy.train_soft_prompts=True
|
--policy.train_soft_prompts=true
|
||||||
```
|
```
|
||||||
|
|
||||||
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
|
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
|
||||||
@@ -172,71 +174,7 @@ lerobot-train \
|
|||||||
**🔥 Full-finetune all components with a custom learning-rate scheme**
|
**🔥 Full-finetune all components with a custom learning-rate scheme**
|
||||||
|
|
||||||
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
|
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
|
||||||
This LR ratio is crucial for achieving strong and stable finetuning performance.
|
This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.
|
||||||
To enable this behavior, you must:
|
|
||||||
|
|
||||||
1. Implement a custom optimizer and register it in your training config
|
|
||||||
|
|
||||||
```
|
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
from lerobot.optim.optimizers import OptimizerConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
@OptimizerConfig.register_subclass("xvla-adamw")
|
|
||||||
@dataclass
|
|
||||||
class XVLAAdamW(OptimizerConfig):
|
|
||||||
lr: float = 1e-4
|
|
||||||
betas: tuple[float, float] = (0.9, 0.99)
|
|
||||||
eps: float = 1e-8
|
|
||||||
weight_decay: float = 0.0
|
|
||||||
grad_clip_norm: float = 10.0
|
|
||||||
|
|
||||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
|
||||||
"""
|
|
||||||
Expect `named_parameters()` as input.
|
|
||||||
Apply lr = lr / 10 for all VLM-related parameters.
|
|
||||||
"""
|
|
||||||
assert isinstance(params, dict), \
|
|
||||||
"Custom LR optimizer requires `named_parameters()` as inputs."
|
|
||||||
kwargs = asdict(self)
|
|
||||||
kwargs.pop("grad_clip_norm")
|
|
||||||
vlm_group, other_group = [], []
|
|
||||||
for name, p in params.items():
|
|
||||||
if not p.requires_grad:
|
|
||||||
continue
|
|
||||||
if "vlm" in name.lower():
|
|
||||||
vlm_group.append(p)
|
|
||||||
else:
|
|
||||||
other_group.append(p)
|
|
||||||
|
|
||||||
param_groups = [
|
|
||||||
{"params": vlm_group, "lr": self.lr * 0.1, "weight_decay": self.weight_decay * 0.1},
|
|
||||||
{"params": other_group, "lr": self.lr, "weight_decay": self.weight_decay},
|
|
||||||
]
|
|
||||||
|
|
||||||
return torch.optim.AdamW(param_groups, **kwargs)
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Modify X-VLA’s get_optim_params to return named parameters
|
|
||||||
|
|
||||||
Replace:
|
|
||||||
|
|
||||||
```
|
|
||||||
def get_optim_params(self) -> dict:
|
|
||||||
"""Return only trainable parameters for optimization."""
|
|
||||||
return filter(lambda p: p.requires_grad, self.parameters())
|
|
||||||
```
|
|
||||||
|
|
||||||
with:
|
|
||||||
|
|
||||||
```
|
|
||||||
def get_optim_params(self):
|
|
||||||
"""Return trainable named parameters."""
|
|
||||||
return filter(lambda kv: kv[1].requires_grad, self.named_parameters())
|
|
||||||
```
|
|
||||||
|
|
||||||
This ensures the optimizer receives a dict of named parameters, allowing it to correctly detect VLM modules and apply the 1/10 LR rule.
|
|
||||||
|
|
||||||
❕Note
|
❕Note
|
||||||
|
|
||||||
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
|
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
|
||||||
@@ -326,6 +264,26 @@ domain_id = 3
|
|||||||
|
|
||||||
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
|
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
|
||||||
|
|
||||||
|
The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:
|
||||||
|
|
||||||
|
#### Fine-tuning Datasets
|
||||||
|
|
||||||
|
| Dataset Name | Domain ID |
|
||||||
|
| ---------------- | --------- |
|
||||||
|
| Bridge | 0 |
|
||||||
|
| RT1 | 1 |
|
||||||
|
| Calvin | 2 |
|
||||||
|
| libero | 3 |
|
||||||
|
| widowx-air | 4 |
|
||||||
|
| AIR-AGILEX-HQ | 5 |
|
||||||
|
| robotwin2_abs_ee | 6 |
|
||||||
|
| robotwin2_clean | 6 |
|
||||||
|
| robocasa-human | 7 |
|
||||||
|
| VLABench | 8 |
|
||||||
|
| AGIBOT-challenge | 9 |
|
||||||
|
| AIR-AGILEX | 10 |
|
||||||
|
| AIRBOT | 18 |
|
||||||
|
|
||||||
### 3. Processor Steps
|
### 3. Processor Steps
|
||||||
|
|
||||||
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
|
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
|
||||||
|
|||||||
Reference in New Issue
Block a user