add freeze/unfreeze options

This commit is contained in:
Jade Choghari
2025-11-24 14:11:23 +01:00
parent 8f2321af27
commit 722766b825
3 changed files with 139 additions and 4 deletions
@@ -84,6 +84,13 @@ class XVLAConfig(PreTrainedConfig):
num_image_views: int | None = None  # NOTE(review): presumably the number of camera views fed to the model; None = infer — confirm against caller
empty_cameras: int = 0  # NOTE(review): looks like a count of padded/blank camera slots — confirm
# Freezing options for VLM components
# By default, VLM encoders are frozen and only policy transformer + soft prompts train
freeze_vision_encoder: bool = True  # Freeze VLM vision encoder weights
freeze_language_encoder: bool = True  # Freeze VLM language encoder weights
train_policy_transformer: bool = True  # Allow policy transformer to train
train_soft_prompts: bool = True  # Allow soft prompts to train
# Training presets
optimizer_lr: float = 1e-4  # base learning rate
optimizer_betas: tuple[float, float] = (0.9, 0.95)  # Adam-style (beta1, beta2)
+62 -1
View File
@@ -83,6 +83,55 @@ class XVLAModel(nn.Module):
use_hetero_proj=config.use_hetero_proj,
)
# Apply freezing based on config
self._apply_freezing()
def _apply_freezing(self) -> None:
"""
Freeze VLM vision and language encoders based on config options.
Keep only policy transformer and soft prompts trainable.
"""
# Freeze vision encoder
if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
for param in self.vlm.vision_tower.parameters():
param.requires_grad = False
# Freeze language encoder
if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
lm = self.vlm.language_model
# Freeze encoder
if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
for param in lm.model.encoder.parameters():
param.requires_grad = False
# Freeze shared embeddings
if hasattr(lm, "model") and hasattr(lm.model, "shared"):
for param in lm.model.shared.parameters():
param.requires_grad = False
# Freeze or unfreeze policy transformer
if not self.config.train_policy_transformer:
for name, param in self.transformer.named_parameters():
if "soft_prompts" not in name:
param.requires_grad = False
# Freeze or unfreeze soft prompts
if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
for param in self.transformer.soft_prompt_hub.parameters():
param.requires_grad = False
def get_trainable_params_summary(self) -> dict[str, int]:
    """
    Summarize the model's parameter counts.

    Returns a dict with trainable/frozen/total element counts and the
    trainable percentage (0.0 when the model has no parameters).
    """
    counts = {True: 0, False: 0}
    # Single pass over parameters, bucketed by requires_grad.
    for param in self.parameters():
        counts[param.requires_grad] += param.numel()
    total = counts[True] + counts[False]
    pct = 100.0 * counts[True] / total if total > 0 else 0.0
    return {
        "trainable": counts[True],
        "frozen": counts[False],
        "total": total,
        "trainable_pct": pct,
    }
def forward_vlm(
self,
input_ids: torch.LongTensor,
@@ -197,13 +246,25 @@ class XVLAPolicy(PreTrainedPolicy):
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
self.reset()
# Log trainable parameters summary
params_summary = self.model.get_trainable_params_summary()
print("XVLA Parameter Summary:")
print(f" Trainable: {params_summary['trainable']:,} ({params_summary['trainable_pct']:.2f}%)")
print(f" Frozen: {params_summary['frozen']:,}")
print(f" Total: {params_summary['total']:,}")
print(f" Vision Encoder: {'Frozen' if config.freeze_vision_encoder else 'Trainable'}")
print(f" Language Encoder: {'Frozen' if config.freeze_language_encoder else 'Trainable'}")
print(f" Policy Transformer: {'Trainable' if config.train_policy_transformer else 'Frozen'}")
print(f" Soft Prompts: {'Trainable' if config.train_soft_prompts else 'Frozen'}")
def reset(self) -> None:
    """Reset inference state by rebuilding the bounded action queue."""
    horizon = self.config.n_action_steps
    self._queues = {ACTION: deque(maxlen=horizon)}
def get_optim_params(self) -> dict:
    """Return only trainable parameters for optimization.

    BUG FIX: a stale ``return self.parameters()`` preceded the filtered
    return, making it unreachable and handing frozen parameters to the
    optimizer; the stale line is removed so only ``requires_grad``
    parameters are yielded.
    """
    return filter(lambda p: p.requires_grad, self.parameters())
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
if not self.config.use_proprio or OBS_STATE not in batch:
+70 -3
View File
@@ -1,5 +1,72 @@
import math
import numpy as np
import robosuite.utils.transform_utils as transform_utils
def mat2quat(rmat):
"""
Converts given rotation matrix to quaternion.
Args:
rmat (np.array): 3x3 rotation matrix
Returns:
np.array: (x,y,z,w) float quaternion angles
"""
mat = np.asarray(rmat).astype(np.float32)[:3, :3]
m00 = mat[0, 0]
m01 = mat[0, 1]
m02 = mat[0, 2]
m10 = mat[1, 0]
m11 = mat[1, 1]
m12 = mat[1, 2]
m20 = mat[2, 0]
m21 = mat[2, 1]
m22 = mat[2, 2]
# symmetric matrix k
k = np.array(
[
[m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
[m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
[m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
[m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
]
)
k /= 3.0
# quaternion is Eigen vector of k that corresponds to largest eigenvalue
w, v = np.linalg.eigh(k)
inds = np.array([3, 0, 1, 2])
q1 = v[inds, np.argmax(w)]
if q1[0] < 0.0:
np.negative(q1, q1)
inds = np.array([1, 2, 3, 0])
return q1[inds]
def quat2axisangle(quat):
    """
    Converts quaternion to axis-angle format.
    Returns a unit vector direction scaled by its angle in radians.

    Args:
        quat (np.array): (x,y,z,w) vec4 float angles

    Returns:
        np.array: (ax,ay,az) axis-angle exponential coordinates

    BUG FIX: the previous version clipped ``quat[3]`` in place, mutating
    the caller's array as a side effect; the clip is now done on a local
    value so the input is left untouched.
    """
    # Clip w into [-1, 1] to guard acos/sqrt against float round-off.
    w = float(np.clip(quat[3], -1.0, 1.0))
    den = np.sqrt(1.0 - w * w)
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)
    return (quat[:3] * 2.0 * math.acos(w)) / den
def rotate6d_to_axis_angle(r6d):
@@ -30,8 +97,8 @@ def rotate6d_to_axis_angle(r6d):
axis_angle_list = []
for i in range(rotation_matrix.shape[0]):
quat = transform_utils.mat2quat(rotation_matrix[i])
axis_angle = transform_utils.quat2axisangle(quat)
quat = mat2quat(rotation_matrix[i])
axis_angle = quat2axisangle(quat)
axis_angle_list.append(axis_angle)
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)