mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
feat(pi0): add train_expert_only and freeze_vision_encoder flags to pi0 and pi0.5 (#2727)
* feat(pi0): add train_expert_only and freeze_vision_encoder options * pi_05: train_expert_only and freeze_vision_encoder flags * comment clean up * docs: add finetuning parameters to pi0 and pi05 docs * updating docs to follow standards
This commit is contained in:
@@ -64,6 +64,8 @@ python src/lerobot/scripts/lerobot_train.py \
|
|||||||
--policy.compile_model=true \
|
--policy.compile_model=true \
|
||||||
--policy.gradient_checkpointing=true \
|
--policy.gradient_checkpointing=true \
|
||||||
--policy.dtype=bfloat16 \
|
--policy.dtype=bfloat16 \
|
||||||
|
--policy.freeze_vision_encoder=false \
|
||||||
|
--policy.train_expert_only=false \
|
||||||
--steps=3000 \
|
--steps=3000 \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--batch_size=32
|
--batch_size=32
|
||||||
@@ -79,6 +81,15 @@ python src/lerobot/scripts/lerobot_train.py \
|
|||||||
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
|
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
|
||||||
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
|
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
|
||||||
|
|
||||||
|
### Training Parameters Explained
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
| ----------------------- | ------- | ------------------------------------------- |
|
||||||
|
| `freeze_vision_encoder` | `false` | Do not freeze the vision encoder |
|
||||||
|
| `train_expert_only` | `false` | Do not freeze the VLM, train all parameters |
|
||||||
|
|
||||||
|
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||||
|
|||||||
@@ -67,6 +67,8 @@ python src/lerobot/scripts/lerobot_train.py\
|
|||||||
--policy.gradient_checkpointing=true \
|
--policy.gradient_checkpointing=true \
|
||||||
--wandb.enable=true \
|
--wandb.enable=true \
|
||||||
--policy.dtype=bfloat16 \
|
--policy.dtype=bfloat16 \
|
||||||
|
--policy.freeze_vision_encoder=false \
|
||||||
|
--policy.train_expert_only=false \
|
||||||
--steps=3000 \
|
--steps=3000 \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--batch_size=32
|
--batch_size=32
|
||||||
@@ -82,6 +84,15 @@ python src/lerobot/scripts/lerobot_train.py\
|
|||||||
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
|
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
|
||||||
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
|
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
|
||||||
|
|
||||||
|
### Training Parameters Explained
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
| ----------------------- | ------- | ------------------------------------------- |
|
||||||
|
| `freeze_vision_encoder` | `false` | Do not freeze the vision encoder |
|
||||||
|
| `train_expert_only` | `false` | Do not freeze the VLM, train all parameters |
|
||||||
|
|
||||||
|
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
|
||||||
|
|
||||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -76,6 +76,10 @@ class PI0Config(PreTrainedConfig):
|
|||||||
compile_mode: str = "max-autotune" # Torch compile mode
|
compile_mode: str = "max-autotune" # Torch compile mode
|
||||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||||
|
|
||||||
|
# Finetuning settings
|
||||||
|
freeze_vision_encoder: bool = False # Freeze only the vision encoder
|
||||||
|
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
|
||||||
|
|
||||||
# Optimizer settings: see openpi `AdamW``
|
# Optimizer settings: see openpi `AdamW``
|
||||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||||
|
|||||||
@@ -339,10 +339,14 @@ class PaliGemmaWithExpertModel(
|
|||||||
use_adarms=None,
|
use_adarms=None,
|
||||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||||
image_size: int = DEFAULT_IMAGE_SIZE,
|
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||||
|
freeze_vision_encoder: bool = False,
|
||||||
|
train_expert_only: bool = False,
|
||||||
):
|
):
|
||||||
if use_adarms is None:
|
if use_adarms is None:
|
||||||
use_adarms = [False, False]
|
use_adarms = [False, False]
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.freeze_vision_encoder = freeze_vision_encoder
|
||||||
|
self.train_expert_only = train_expert_only
|
||||||
|
|
||||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
||||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
||||||
@@ -383,6 +387,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
self.gemma_expert.model.embed_tokens = None
|
self.gemma_expert.model.embed_tokens = None
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
|
self._set_requires_grad()
|
||||||
|
|
||||||
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
||||||
if precision == "bfloat16":
|
if precision == "bfloat16":
|
||||||
@@ -406,6 +411,23 @@ class PaliGemmaWithExpertModel(
|
|||||||
if any(selector in name for selector in params_to_keep_float32):
|
if any(selector in name for selector in params_to_keep_float32):
|
||||||
param.data = param.data.to(dtype=torch.float32)
|
param.data = param.data.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
def _set_requires_grad(self):
|
||||||
|
if self.freeze_vision_encoder:
|
||||||
|
self.paligemma.vision_tower.eval()
|
||||||
|
for param in self.paligemma.vision_tower.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
if self.train_expert_only:
|
||||||
|
self.paligemma.eval()
|
||||||
|
for param in self.paligemma.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def train(self, mode: bool = True):
|
||||||
|
super().train(mode)
|
||||||
|
if self.freeze_vision_encoder:
|
||||||
|
self.paligemma.vision_tower.eval()
|
||||||
|
if self.train_expert_only:
|
||||||
|
self.paligemma.eval()
|
||||||
|
|
||||||
def embed_image(self, image: torch.Tensor):
|
def embed_image(self, image: torch.Tensor):
|
||||||
return self.paligemma.model.get_image_features(image)
|
return self.paligemma.model.get_image_features(image)
|
||||||
|
|
||||||
@@ -533,6 +555,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
use_adarms=[False, False],
|
use_adarms=[False, False],
|
||||||
precision=config.dtype,
|
precision=config.dtype,
|
||||||
image_size=config.image_resolution[0],
|
image_size=config.image_resolution[0],
|
||||||
|
freeze_vision_encoder=config.freeze_vision_encoder,
|
||||||
|
train_expert_only=config.train_expert_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||||
|
|||||||
@@ -76,6 +76,10 @@ class PI05Config(PreTrainedConfig):
|
|||||||
compile_mode: str = "max-autotune" # Torch compile mode
|
compile_mode: str = "max-autotune" # Torch compile mode
|
||||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||||
|
|
||||||
|
# Finetuning settings
|
||||||
|
freeze_vision_encoder: bool = False # Freeze only the vision encoder
|
||||||
|
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
|
||||||
|
|
||||||
# Optimizer settings: see openpi `AdamW`
|
# Optimizer settings: see openpi `AdamW`
|
||||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||||
|
|||||||
@@ -337,10 +337,14 @@ class PaliGemmaWithExpertModel(
|
|||||||
use_adarms=None,
|
use_adarms=None,
|
||||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||||
image_size: int = DEFAULT_IMAGE_SIZE,
|
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||||
|
freeze_vision_encoder: bool = False,
|
||||||
|
train_expert_only: bool = False,
|
||||||
):
|
):
|
||||||
if use_adarms is None:
|
if use_adarms is None:
|
||||||
use_adarms = [False, False]
|
use_adarms = [False, False]
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.freeze_vision_encoder = freeze_vision_encoder
|
||||||
|
self.train_expert_only = train_expert_only
|
||||||
|
|
||||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
||||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
||||||
@@ -381,6 +385,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
self.gemma_expert.model.embed_tokens = None
|
self.gemma_expert.model.embed_tokens = None
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
|
self._set_requires_grad()
|
||||||
|
|
||||||
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
||||||
if precision == "bfloat16":
|
if precision == "bfloat16":
|
||||||
@@ -404,6 +409,23 @@ class PaliGemmaWithExpertModel(
|
|||||||
if any(selector in name for selector in params_to_keep_float32):
|
if any(selector in name for selector in params_to_keep_float32):
|
||||||
param.data = param.data.to(dtype=torch.float32)
|
param.data = param.data.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
def _set_requires_grad(self):
|
||||||
|
if self.freeze_vision_encoder:
|
||||||
|
self.paligemma.vision_tower.eval()
|
||||||
|
for param in self.paligemma.vision_tower.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
if self.train_expert_only:
|
||||||
|
self.paligemma.eval()
|
||||||
|
for param in self.paligemma.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def train(self, mode: bool = True):
|
||||||
|
super().train(mode)
|
||||||
|
if self.freeze_vision_encoder:
|
||||||
|
self.paligemma.vision_tower.eval()
|
||||||
|
if self.train_expert_only:
|
||||||
|
self.paligemma.eval()
|
||||||
|
|
||||||
def embed_image(self, image: torch.Tensor):
|
def embed_image(self, image: torch.Tensor):
|
||||||
return self.paligemma.model.get_image_features(image)
|
return self.paligemma.model.get_image_features(image)
|
||||||
|
|
||||||
@@ -531,6 +553,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
use_adarms=[False, True],
|
use_adarms=[False, True],
|
||||||
precision=config.dtype,
|
precision=config.dtype,
|
||||||
image_size=config.image_resolution[0],
|
image_size=config.image_resolution[0],
|
||||||
|
freeze_vision_encoder=config.freeze_vision_encoder,
|
||||||
|
train_expert_only=config.train_expert_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||||
|
|||||||
Reference in New Issue
Block a user