This commit is contained in:
Jade Choghari
2025-11-24 10:36:32 +01:00
parent 15188b0cf8
commit 5052d4d70b
+14 -21
View File
@@ -425,32 +425,25 @@ class BimanualSO101ActionSpace(BaseActionSpace):
# Ensure both are [B, T, 20]
pred = self._pad_to_model_dim(pred)
target = self._pad_to_model_dim(target)
assert pred.shape == target.shape, "pred/target shapes must match"
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
assert pred.shape == target.shape
# --- Gripper BCE loss (only real gripper indices) ---
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
# --- Joint positions MSE (all non-gripper *real* dims) ---
real_set = set(self.REAL_IDXS)
joints_idx = tuple(i for i in real_set if i not in set(self.gripper_idx))
# ---- MSE for all real dims (011) ----
REAL_DIMS = 12
joints_loss = self.mse(
pred[:, :, joints_idx],
target[:, :, joints_idx],
pred[:, :, :REAL_DIMS],
target[:, :, :REAL_DIMS],
) * self.JOINTS_SCALE
# Separate losses for left and right arms for better monitoring
left_arm_loss = self.mse(
pred[:, :, self.LEFT_ARM_JOINTS],
target[:, :, self.LEFT_ARM_JOINTS],
)
right_arm_loss = self.mse(
pred[:, :, self.RIGHT_ARM_JOINTS],
target[:, :, self.RIGHT_ARM_JOINTS],
)
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
# is gripper continuous? not bce?
gripper_loss = self.mse(
pred[:, :, [5, 11]],
target[:, :, [5, 11]],
) * self.GRIPPER_SCALE
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,