feat(pi0): add train_expert_only and freeze_vision_encoder flags to pi0 and pi0.5 (#2727)

* feat(pi0): add train_expert_only and freeze_vision_encoder options

* pi_05: train_expert_only and freeze_vision_encoder flags

* comment clean up

* docs: add finetuning parameters to pi0 and pi05 docs

* updating docs to follow standards
arya committed 2025-12-31 06:54:28 -08:00 (committed by GitHub)
parent 6d0d65a5fe
commit 9701b9c273

6 changed files with 78 additions and 0 deletions
@@ -64,6 +64,8 @@ python src/lerobot/scripts/lerobot_train.py \
--policy.compile_model=true \
--policy.gradient_checkpointing=true \
--policy.dtype=bfloat16 \
--policy.freeze_vision_encoder=false \
--policy.train_expert_only=false \
--steps=3000 \
--policy.device=cuda \
--batch_size=32
@@ -79,6 +81,15 @@ python src/lerobot/scripts/lerobot_train.py \
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
### Training Parameters Explained
| Parameter               | Default | Description                                                                       |
| ----------------------- | ------- | --------------------------------------------------------------------------------- |
| `freeze_vision_encoder` | `false` | If `true`, freeze the vision encoder during finetuning                            |
| `train_expert_only`     | `false` | If `true`, freeze the entire VLM and train only the action expert and projections |
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
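For a quick sanity check that the flag took effect, count trainable versus total parameters after loading the policy. A minimal sketch; the import path and `from_pretrained` usage are assumptions that may differ across lerobot versions:

```python
# Minimal sketch (assumed import path; adjust to your lerobot version):
# with train_expert_only=true, the trainable count should drop to roughly
# the action expert plus the projection layers.
from lerobot.policies.pi0.modeling_pi0 import PI0Policy

policy = PI0Policy.from_pretrained("lerobot/pi0_base")

trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
total = sum(p.numel() for p in policy.parameters())
print(f"trainable: {trainable / 1e6:.1f}M of {total / 1e6:.1f}M parameters")
```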
## License
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
@@ -67,6 +67,8 @@ python src/lerobot/scripts/lerobot_train.py\
--policy.gradient_checkpointing=true \
--wandb.enable=true \
--policy.dtype=bfloat16 \
--policy.freeze_vision_encoder=false \
--policy.train_expert_only=false \
--steps=3000 \
--policy.device=cuda \
--batch_size=32
@@ -82,6 +84,15 @@ python src/lerobot/scripts/lerobot_train.py\
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
### Training Parameters Explained
| Parameter               | Default | Description                                                                       |
| ----------------------- | ------- | --------------------------------------------------------------------------------- |
| `freeze_vision_encoder` | `false` | If `true`, freeze the vision encoder during finetuning                            |
| `train_expert_only`     | `false` | If `true`, freeze the entire VLM and train only the action expert and projections |
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
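Beyond a raw parameter count, a per-submodule breakdown shows exactly which parts the flags froze. A minimal sketch that works on any `nn.Module`; the top-level prefixes (`paligemma`, `gemma_expert`) follow the model code shown further below and may differ in your version:

```python
from collections import defaultdict

import torch

def requires_grad_breakdown(model: torch.nn.Module) -> dict[str, tuple[int, int]]:
    """Return {top-level prefix: (trainable params, total params)}."""
    stats: dict[str, list[int]] = defaultdict(lambda: [0, 0])
    for name, param in model.named_parameters():
        prefix = name.split(".")[0]  # e.g. "paligemma" or "gemma_expert"
        if param.requires_grad:
            stats[prefix][0] += param.numel()
        stats[prefix][1] += param.numel()
    return {k: (v[0], v[1]) for k, v in stats.items()}
```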
If your dataset is not converted with `quantiles`, you can convert it with the following command:
```bash
@@ -76,6 +76,10 @@ class PI0Config(PreTrainedConfig):
    compile_mode: str = "max-autotune"  # Torch compile mode
    device: str | None = None  # Device to use for the model (None = auto-detect)

    # Finetuning settings
    freeze_vision_encoder: bool = False  # Freeze only the vision encoder
    train_expert_only: bool = False  # Freeze the entire VLM; train only the action expert and projections

    # Optimizer settings: see openpi `AdamW`
    optimizer_lr: float = 2.5e-5  # see openpi `CosineDecaySchedule: peak_lr`
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
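One practical consequence of these flags: parameters with `requires_grad=False` should not be handed to the optimizer, so AdamW keeps no moment buffers for them (this is where most of the memory saving comes from). A sketch of that pattern using the defaults above; it is illustrative, not lerobot's actual trainer code:

```python
import torch

# Sketch: build the optimizer over trainable parameters only, so the frozen
# VLM weights cost no AdamW state. lr/betas mirror the config defaults above.
def make_optimizer(model: torch.nn.Module) -> torch.optim.AdamW:
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable_params, lr=2.5e-5, betas=(0.9, 0.95))
```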
@@ -339,10 +339,14 @@ class PaliGemmaWithExpertModel(
        use_adarms=None,
        precision: Literal["bfloat16", "float32"] = "bfloat16",
        image_size: int = DEFAULT_IMAGE_SIZE,
        freeze_vision_encoder: bool = False,
        train_expert_only: bool = False,
    ):
        if use_adarms is None:
            use_adarms = [False, False]
        super().__init__()
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only

        vlm_config_hf = CONFIG_MAPPING["paligemma"]()
        vlm_config_hf._vocab_size = 257152  # noqa: SLF001
@@ -383,6 +387,7 @@ class PaliGemmaWithExpertModel(
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_for_selected_params(precision)
        self._set_requires_grad()

    def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
        if precision == "bfloat16":
@@ -406,6 +411,23 @@ class PaliGemmaWithExpertModel(
                if any(selector in name for selector in params_to_keep_float32):
                    param.data = param.data.to(dtype=torch.float32)

    def _set_requires_grad(self):
        if self.freeze_vision_encoder:
            # Keep the vision tower in eval mode and exclude it from gradient updates.
            self.paligemma.vision_tower.eval()
            for param in self.paligemma.vision_tower.parameters():
                param.requires_grad = False
        if self.train_expert_only:
            # Freeze the whole VLM (vision tower included); only the action expert
            # and the projection layers remain trainable.
            self.paligemma.eval()
            for param in self.paligemma.parameters():
                param.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)
        # Re-pin frozen submodules to eval mode: nn.Module.train(True) would
        # otherwise switch them back to training-time behavior.
        if self.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
        if self.train_expert_only:
            self.paligemma.eval()

    def embed_image(self, image: torch.Tensor):
        return self.paligemma.model.get_image_features(image)
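The `train()` override above is load-bearing: `requires_grad=False` only blocks gradient updates, while a later `.train()` call on the wrapper would still switch frozen submodules back to training-time behavior (e.g. active dropout). A self-contained sketch of the same eval-pinning pattern:

```python
import torch
from torch import nn

class FrozenBackboneModel(nn.Module):
    """Toy illustration of the eval-pinning pattern used above."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
        self.head = nn.Linear(8, 2)
        # Freeze the backbone and pin it to eval mode.
        self.backbone.eval()
        for param in self.backbone.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)   # flips every submodule, including the backbone
        self.backbone.eval()  # re-pin: keep dropout disabled in the frozen part
        return self

model = FrozenBackboneModel().train()
assert not model.backbone.training and model.head.training
```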
@@ -533,6 +555,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
            use_adarms=[False, False],
            precision=config.dtype,
            image_size=config.image_resolution[0],
            freeze_vision_encoder=config.freeze_vision_encoder,
            train_expert_only=config.train_expert_only,
        )

        self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -76,6 +76,10 @@ class PI05Config(PreTrainedConfig):
    compile_mode: str = "max-autotune"  # Torch compile mode
    device: str | None = None  # Device to use for the model (None = auto-detect)

    # Finetuning settings
    freeze_vision_encoder: bool = False  # Freeze only the vision encoder
    train_expert_only: bool = False  # Freeze the entire VLM; train only the action expert and projections

    # Optimizer settings: see openpi `AdamW`
    optimizer_lr: float = 2.5e-5  # see openpi `CosineDecaySchedule: peak_lr`
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
@@ -337,10 +337,14 @@ class PaliGemmaWithExpertModel(
        use_adarms=None,
        precision: Literal["bfloat16", "float32"] = "bfloat16",
        image_size: int = DEFAULT_IMAGE_SIZE,
        freeze_vision_encoder: bool = False,
        train_expert_only: bool = False,
    ):
        if use_adarms is None:
            use_adarms = [False, False]
        super().__init__()
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only

        vlm_config_hf = CONFIG_MAPPING["paligemma"]()
        vlm_config_hf._vocab_size = 257152  # noqa: SLF001
@@ -381,6 +385,7 @@ class PaliGemmaWithExpertModel(
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_for_selected_params(precision)
        self._set_requires_grad()

    def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
        if precision == "bfloat16":
@@ -404,6 +409,23 @@ class PaliGemmaWithExpertModel(
                if any(selector in name for selector in params_to_keep_float32):
                    param.data = param.data.to(dtype=torch.float32)

    def _set_requires_grad(self):
        if self.freeze_vision_encoder:
            # Keep the vision tower in eval mode and exclude it from gradient updates.
            self.paligemma.vision_tower.eval()
            for param in self.paligemma.vision_tower.parameters():
                param.requires_grad = False
        if self.train_expert_only:
            # Freeze the whole VLM (vision tower included); only the action expert
            # and the projection layers remain trainable.
            self.paligemma.eval()
            for param in self.paligemma.parameters():
                param.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)
        # Re-pin frozen submodules to eval mode: nn.Module.train(True) would
        # otherwise switch them back to training-time behavior.
        if self.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
        if self.train_expert_only:
            self.paligemma.eval()

    def embed_image(self, image: torch.Tensor):
        return self.paligemma.model.get_image_features(image)
@@ -531,6 +553,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
            use_adarms=[False, True],
            precision=config.dtype,
            image_size=config.image_resolution[0],
            freeze_vision_encoder=config.freeze_vision_encoder,
            train_expert_only=config.train_expert_only,
        )

        self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
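Finally, the freezing invariants implied by both `PaliGemmaWithExpertModel` variants can be asserted directly. A sketch that assumes only the attribute names visible in this diff (`paligemma`, `vision_tower`, `gemma_expert`):

```python
def audit_freezing(model) -> None:
    """Assert the invariants implied by freeze_vision_encoder / train_expert_only."""
    if model.freeze_vision_encoder or model.train_expert_only:
        # train_expert_only freezes all of paligemma, which includes the vision tower.
        assert not any(p.requires_grad for p in model.paligemma.vision_tower.parameters())
        assert not model.paligemma.vision_tower.training
    if model.train_expert_only:
        assert not any(p.requires_grad for p in model.paligemma.parameters())
        # The action expert is never frozen by these flags.
        assert all(p.requires_grad for p in model.gemma_expert.parameters())
```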