mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
more
This commit is contained in:
@@ -425,32 +425,25 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
||||
# Ensure both are [B, T, 20]
|
||||
pred = self._pad_to_model_dim(pred)
|
||||
target = self._pad_to_model_dim(target)
|
||||
assert pred.shape == target.shape, "pred/target shapes must match"
|
||||
batch_size, seq_len, action_dim = pred.shape
|
||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||
assert pred.shape == target.shape
|
||||
|
||||
# --- Gripper BCE loss (only real gripper indices) ---
|
||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
||||
|
||||
# --- Joint positions MSE (all non-gripper *real* dims) ---
|
||||
real_set = set(self.REAL_IDXS)
|
||||
joints_idx = tuple(i for i in real_set if i not in set(self.gripper_idx))
|
||||
# ---- MSE for all real dims (0–11) ----
|
||||
REAL_DIMS = 12
|
||||
|
||||
joints_loss = self.mse(
|
||||
pred[:, :, joints_idx],
|
||||
target[:, :, joints_idx],
|
||||
pred[:, :, :REAL_DIMS],
|
||||
target[:, :, :REAL_DIMS],
|
||||
) * self.JOINTS_SCALE
|
||||
|
||||
# Separate losses for left and right arms for better monitoring
|
||||
left_arm_loss = self.mse(
|
||||
pred[:, :, self.LEFT_ARM_JOINTS],
|
||||
target[:, :, self.LEFT_ARM_JOINTS],
|
||||
)
|
||||
right_arm_loss = self.mse(
|
||||
pred[:, :, self.RIGHT_ARM_JOINTS],
|
||||
target[:, :, self.RIGHT_ARM_JOINTS],
|
||||
)
|
||||
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
||||
|
||||
# is gripper continuous? not bce?
|
||||
gripper_loss = self.mse(
|
||||
pred[:, :, [5, 11]],
|
||||
target[:, :, [5, 11]],
|
||||
) * self.GRIPPER_SCALE
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
|
||||
Reference in New Issue
Block a user