mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
add freeze/unfreeze options
This commit is contained in:
@@ -84,6 +84,13 @@ class XVLAConfig(PreTrainedConfig):
|
|||||||
num_image_views: int | None = None
|
num_image_views: int | None = None
|
||||||
empty_cameras: int = 0
|
empty_cameras: int = 0
|
||||||
|
|
||||||
|
# Freezing options for VLM components
|
||||||
|
# By default, VLM encoders are frozen and only policy transformer + soft prompts train
|
||||||
|
freeze_vision_encoder: bool = True # Freeze VLM vision encoder weights
|
||||||
|
freeze_language_encoder: bool = True # Freeze VLM language encoder weights
|
||||||
|
train_policy_transformer: bool = True # Allow policy transformer to train
|
||||||
|
train_soft_prompts: bool = True # Allow soft prompts to train
|
||||||
|
|
||||||
# Training presets
|
# Training presets
|
||||||
optimizer_lr: float = 1e-4
|
optimizer_lr: float = 1e-4
|
||||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||||
|
|||||||
@@ -83,6 +83,55 @@ class XVLAModel(nn.Module):
|
|||||||
use_hetero_proj=config.use_hetero_proj,
|
use_hetero_proj=config.use_hetero_proj,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply freezing based on config
|
||||||
|
self._apply_freezing()
|
||||||
|
|
||||||
|
def _apply_freezing(self) -> None:
|
||||||
|
"""
|
||||||
|
Freeze VLM vision and language encoders based on config options.
|
||||||
|
Keep only policy transformer and soft prompts trainable.
|
||||||
|
"""
|
||||||
|
# Freeze vision encoder
|
||||||
|
if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
|
||||||
|
for param in self.vlm.vision_tower.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Freeze language encoder
|
||||||
|
if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
|
||||||
|
lm = self.vlm.language_model
|
||||||
|
# Freeze encoder
|
||||||
|
if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
|
||||||
|
for param in lm.model.encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
# Freeze shared embeddings
|
||||||
|
if hasattr(lm, "model") and hasattr(lm.model, "shared"):
|
||||||
|
for param in lm.model.shared.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Freeze or unfreeze policy transformer
|
||||||
|
if not self.config.train_policy_transformer:
|
||||||
|
for name, param in self.transformer.named_parameters():
|
||||||
|
if "soft_prompts" not in name:
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Freeze or unfreeze soft prompts
|
||||||
|
if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
|
||||||
|
for param in self.transformer.soft_prompt_hub.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def get_trainable_params_summary(self) -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
Returns a summary of trainable vs frozen parameters.
|
||||||
|
"""
|
||||||
|
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||||
|
frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
|
||||||
|
return {
|
||||||
|
"trainable": trainable,
|
||||||
|
"frozen": frozen,
|
||||||
|
"total": trainable + frozen,
|
||||||
|
"trainable_pct": 100.0 * trainable / (trainable + frozen) if (trainable + frozen) > 0 else 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
def forward_vlm(
|
def forward_vlm(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
@@ -197,13 +246,25 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
|
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
# Log trainable parameters summary
|
||||||
|
params_summary = self.model.get_trainable_params_summary()
|
||||||
|
print("XVLA Parameter Summary:")
|
||||||
|
print(f" Trainable: {params_summary['trainable']:,} ({params_summary['trainable_pct']:.2f}%)")
|
||||||
|
print(f" Frozen: {params_summary['frozen']:,}")
|
||||||
|
print(f" Total: {params_summary['total']:,}")
|
||||||
|
print(f" Vision Encoder: {'Frozen' if config.freeze_vision_encoder else 'Trainable'}")
|
||||||
|
print(f" Language Encoder: {'Frozen' if config.freeze_language_encoder else 'Trainable'}")
|
||||||
|
print(f" Policy Transformer: {'Trainable' if config.train_policy_transformer else 'Frozen'}")
|
||||||
|
print(f" Soft Prompts: {'Trainable' if config.train_soft_prompts else 'Frozen'}")
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self._queues = {
|
self._queues = {
|
||||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
return self.parameters()
|
"""Return only trainable parameters for optimization."""
|
||||||
|
return filter(lambda p: p.requires_grad, self.parameters())
|
||||||
|
|
||||||
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
||||||
if not self.config.use_proprio or OBS_STATE not in batch:
|
if not self.config.use_proprio or OBS_STATE not in batch:
|
||||||
|
|||||||
@@ -1,5 +1,72 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import robosuite.utils.transform_utils as transform_utils
|
|
||||||
|
|
||||||
|
def mat2quat(rmat):
|
||||||
|
"""
|
||||||
|
Converts given rotation matrix to quaternion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rmat (np.array): 3x3 rotation matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: (x,y,z,w) float quaternion angles
|
||||||
|
"""
|
||||||
|
mat = np.asarray(rmat).astype(np.float32)[:3, :3]
|
||||||
|
|
||||||
|
m00 = mat[0, 0]
|
||||||
|
m01 = mat[0, 1]
|
||||||
|
m02 = mat[0, 2]
|
||||||
|
m10 = mat[1, 0]
|
||||||
|
m11 = mat[1, 1]
|
||||||
|
m12 = mat[1, 2]
|
||||||
|
m20 = mat[2, 0]
|
||||||
|
m21 = mat[2, 1]
|
||||||
|
m22 = mat[2, 2]
|
||||||
|
# symmetric matrix k
|
||||||
|
k = np.array(
|
||||||
|
[
|
||||||
|
[m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
|
||||||
|
[m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
|
||||||
|
[m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
|
||||||
|
[m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
k /= 3.0
|
||||||
|
# quaternion is Eigen vector of k that corresponds to largest eigenvalue
|
||||||
|
w, v = np.linalg.eigh(k)
|
||||||
|
inds = np.array([3, 0, 1, 2])
|
||||||
|
q1 = v[inds, np.argmax(w)]
|
||||||
|
if q1[0] < 0.0:
|
||||||
|
np.negative(q1, q1)
|
||||||
|
inds = np.array([1, 2, 3, 0])
|
||||||
|
return q1[inds]
|
||||||
|
|
||||||
|
|
||||||
|
def quat2axisangle(quat):
|
||||||
|
"""
|
||||||
|
Converts quaternion to axis-angle format.
|
||||||
|
Returns a unit vector direction scaled by its angle in radians.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quat (np.array): (x,y,z,w) vec4 float angles
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: (ax,ay,az) axis-angle exponential coordinates
|
||||||
|
"""
|
||||||
|
# clip quaternion
|
||||||
|
if quat[3] > 1.0:
|
||||||
|
quat[3] = 1.0
|
||||||
|
elif quat[3] < -1.0:
|
||||||
|
quat[3] = -1.0
|
||||||
|
|
||||||
|
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||||
|
if math.isclose(den, 0.0):
|
||||||
|
# This is (close to) a zero degree rotation, immediately return
|
||||||
|
return np.zeros(3)
|
||||||
|
|
||||||
|
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||||
|
|
||||||
|
|
||||||
def rotate6d_to_axis_angle(r6d):
|
def rotate6d_to_axis_angle(r6d):
|
||||||
@@ -30,8 +97,8 @@ def rotate6d_to_axis_angle(r6d):
|
|||||||
|
|
||||||
axis_angle_list = []
|
axis_angle_list = []
|
||||||
for i in range(rotation_matrix.shape[0]):
|
for i in range(rotation_matrix.shape[0]):
|
||||||
quat = transform_utils.mat2quat(rotation_matrix[i])
|
quat = mat2quat(rotation_matrix[i])
|
||||||
axis_angle = transform_utils.quat2axisangle(quat)
|
axis_angle = quat2axisangle(quat)
|
||||||
axis_angle_list.append(axis_angle)
|
axis_angle_list.append(axis_angle)
|
||||||
|
|
||||||
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
|
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
|
||||||
|
|||||||
Reference in New Issue
Block a user