diff --git a/src/lerobot/policies/xvla/configuration_xvla.py b/src/lerobot/policies/xvla/configuration_xvla.py
index 60ebfe911..2ecba245f 100644
--- a/src/lerobot/policies/xvla/configuration_xvla.py
+++ b/src/lerobot/policies/xvla/configuration_xvla.py
@@ -84,6 +84,13 @@ class XVLAConfig(PreTrainedConfig):
     num_image_views: int | None = None
     empty_cameras: int = 0
 
+    # Freezing options for VLM components.
+    # By default, VLM encoders are frozen and only the policy transformer + soft prompts train.
+    freeze_vision_encoder: bool = True  # Freeze VLM vision encoder weights
+    freeze_language_encoder: bool = True  # Freeze VLM language encoder weights
+    train_policy_transformer: bool = True  # Allow policy transformer to train
+    train_soft_prompts: bool = True  # Allow soft prompts to train
+
     # Training presets
     optimizer_lr: float = 1e-4
     optimizer_betas: tuple[float, float] = (0.9, 0.95)
diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py
index 444dc0552..fd5dc6401 100644
--- a/src/lerobot/policies/xvla/modeling_xvla.py
+++ b/src/lerobot/policies/xvla/modeling_xvla.py
@@ -83,6 +83,55 @@ class XVLAModel(nn.Module):
             use_hetero_proj=config.use_hetero_proj,
         )
 
+        # Apply freezing based on config
+        self._apply_freezing()
+
+    def _apply_freezing(self) -> None:
+        """
+        Freeze VLM vision and language encoders based on config options.
+        Keep only policy transformer and soft prompts trainable.
+        """
+        # Freeze vision encoder
+        if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
+            for param in self.vlm.vision_tower.parameters():
+                param.requires_grad = False
+
+        # Freeze language encoder
+        if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
+            lm = self.vlm.language_model
+            # Freeze encoder
+            if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
+                for param in lm.model.encoder.parameters():
+                    param.requires_grad = False
+            # Freeze shared embeddings
+            if hasattr(lm, "model") and hasattr(lm.model, "shared"):
+                for param in lm.model.shared.parameters():
+                    param.requires_grad = False
+
+        # Freeze or unfreeze policy transformer
+        if not self.config.train_policy_transformer:
+            for name, param in self.transformer.named_parameters():
+                if "soft_prompts" not in name:
+                    param.requires_grad = False
+
+        # Freeze or unfreeze soft prompts
+        if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
+            for param in self.transformer.soft_prompt_hub.parameters():
+                param.requires_grad = False
+
+    def get_trainable_params_summary(self) -> dict[str, int]:
+        """
+        Returns a summary of trainable vs frozen parameters.
+        """
+        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
+        return {
+            "trainable": trainable,
+            "frozen": frozen,
+            "total": trainable + frozen,
+            "trainable_pct": 100.0 * trainable / (trainable + frozen) if (trainable + frozen) > 0 else 0.0,
+        }
+
     def forward_vlm(
         self,
         input_ids: torch.LongTensor,
@@ -197,13 +246,25 @@ class XVLAPolicy(PreTrainedPolicy):
         self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
         self.reset()
 
+        # Log trainable parameters summary
+        params_summary = self.model.get_trainable_params_summary()
+        print("XVLA Parameter Summary:")
+        print(f"  Trainable: {params_summary['trainable']:,} ({params_summary['trainable_pct']:.2f}%)")
+        print(f"  Frozen: {params_summary['frozen']:,}")
+        print(f"  Total: {params_summary['total']:,}")
+        print(f"  Vision Encoder: {'Frozen' if config.freeze_vision_encoder else 'Trainable'}")
+        print(f"  Language Encoder: {'Frozen' if config.freeze_language_encoder else 'Trainable'}")
+        print(f"  Policy Transformer: {'Trainable' if config.train_policy_transformer else 'Frozen'}")
+        print(f"  Soft Prompts: {'Trainable' if config.train_soft_prompts else 'Frozen'}")
+
     def reset(self) -> None:
         self._queues = {
             ACTION: deque(maxlen=self.config.n_action_steps),
         }
 
     def get_optim_params(self) -> dict:
-        return self.parameters()
+        """Return only trainable parameters for optimization (as a reusable list, not a one-shot iterator)."""
+        return [p for p in self.parameters() if p.requires_grad]
 
     def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
         if not self.config.use_proprio or OBS_STATE not in batch:
diff --git a/src/lerobot/policies/xvla/utils.py b/src/lerobot/policies/xvla/utils.py
index 38e3e1f20..73793981e 100644
--- a/src/lerobot/policies/xvla/utils.py
+++ b/src/lerobot/policies/xvla/utils.py
@@ -1,5 +1,72 @@
+import math
+
 import numpy as np
-import robosuite.utils.transform_utils as transform_utils
+
+
+def mat2quat(rmat):
+    """
+    Converts given rotation matrix to quaternion.
+
+    Args:
+        rmat (np.array): 3x3 rotation matrix
+
+    Returns:
+        np.array: (x,y,z,w) float quaternion angles
+    """
+    mat = np.asarray(rmat).astype(np.float32)[:3, :3]
+
+    m00 = mat[0, 0]
+    m01 = mat[0, 1]
+    m02 = mat[0, 2]
+    m10 = mat[1, 0]
+    m11 = mat[1, 1]
+    m12 = mat[1, 2]
+    m20 = mat[2, 0]
+    m21 = mat[2, 1]
+    m22 = mat[2, 2]
+    # symmetric matrix k
+    k = np.array(
+        [
+            [m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
+            [m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
+            [m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
+            [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
+        ]
+    )
+    k /= 3.0
+    # quaternion is Eigen vector of k that corresponds to largest eigenvalue
+    w, v = np.linalg.eigh(k)
+    inds = np.array([3, 0, 1, 2])
+    q1 = v[inds, np.argmax(w)]
+    if q1[0] < 0.0:
+        np.negative(q1, q1)
+    inds = np.array([1, 2, 3, 0])
+    return q1[inds]
+
+
+def quat2axisangle(quat):
+    """
+    Converts quaternion to axis-angle format.
+    Returns a unit vector direction scaled by its angle in radians.
+
+    Args:
+        quat (np.array): (x,y,z,w) vec4 float angles
+
+    Returns:
+        np.array: (ax,ay,az) axis-angle exponential coordinates
+    """
+    # clip quaternion
+    if quat[3] > 1.0:
+        quat[3] = 1.0
+    elif quat[3] < -1.0:
+        quat[3] = -1.0
+
+    den = np.sqrt(1.0 - quat[3] * quat[3])
+    if math.isclose(den, 0.0):
+        # This is (close to) a zero degree rotation, immediately return
+        return np.zeros(3)
+
+    return (quat[:3] * 2.0 * math.acos(quat[3])) / den
 
 
 def rotate6d_to_axis_angle(r6d):
@@ -30,8 +97,8 @@ def rotate6d_to_axis_angle(r6d):
     axis_angle_list = []
 
     for i in range(rotation_matrix.shape[0]):
-        quat = transform_utils.mat2quat(rotation_matrix[i])
-        axis_angle = transform_utils.quat2axisangle(quat)
+        quat = mat2quat(rotation_matrix[i])
+        axis_angle = quat2axisangle(quat)
         axis_angle_list.append(axis_angle)
 
     axis_angle_array = np.stack(axis_angle_list, axis=0)  # shape: (N, 3)