diff --git a/src/lerobot/policies/xvla/action_hub.py b/src/lerobot/policies/xvla/action_hub.py index 0de151ec7..0caf35708 100644 --- a/src/lerobot/policies/xvla/action_hub.py +++ b/src/lerobot/policies/xvla/action_hub.py @@ -425,32 +425,25 @@ class BimanualSO101ActionSpace(BaseActionSpace): # Ensure both are [B, T, 20] pred = self._pad_to_model_dim(pred) target = self._pad_to_model_dim(target) - assert pred.shape == target.shape, "pred/target shapes must match" - batch_size, seq_len, action_dim = pred.shape - _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx") + assert pred.shape == target.shape - # --- Gripper BCE loss (only real gripper indices) --- - g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx] - gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE - - # --- Joint positions MSE (all non-gripper *real* dims) --- - real_set = set(self.REAL_IDXS) - joints_idx = tuple(i for i in real_set if i not in set(self.gripper_idx)) + # ---- MSE for all real dims (0–11) ---- + REAL_DIMS = 12 joints_loss = self.mse( - pred[:, :, joints_idx], - target[:, :, joints_idx], + pred[:, :, :REAL_DIMS], + target[:, :, :REAL_DIMS], ) * self.JOINTS_SCALE - # Separate losses for left and right arms for better monitoring - left_arm_loss = self.mse( - pred[:, :, self.LEFT_ARM_JOINTS], - target[:, :, self.LEFT_ARM_JOINTS], - ) - right_arm_loss = self.mse( - pred[:, :, self.RIGHT_ARM_JOINTS], - target[:, :, self.RIGHT_ARM_JOINTS], - ) + left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6]) + right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12]) + + # is gripper continuous? not bce? + gripper_loss = self.mse( + pred[:, :, [5, 11]], + target[:, :, [5, 11]], + ) * self.GRIPPER_SCALE + return { "joints_loss": joints_loss, "gripper_loss": gripper_loss,