mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
more
This commit is contained in:
@@ -425,32 +425,25 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
|||||||
# Ensure both are [B, T, 20]
|
# Ensure both are [B, T, 20]
|
||||||
pred = self._pad_to_model_dim(pred)
|
pred = self._pad_to_model_dim(pred)
|
||||||
target = self._pad_to_model_dim(target)
|
target = self._pad_to_model_dim(target)
|
||||||
assert pred.shape == target.shape, "pred/target shapes must match"
|
assert pred.shape == target.shape
|
||||||
batch_size, seq_len, action_dim = pred.shape
|
|
||||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
|
||||||
|
|
||||||
# --- Gripper BCE loss (only real gripper indices) ---
|
# ---- MSE for all real dims (0–11) ----
|
||||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
REAL_DIMS = 12
|
||||||
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
|
||||||
|
|
||||||
# --- Joint positions MSE (all non-gripper *real* dims) ---
|
|
||||||
real_set = set(self.REAL_IDXS)
|
|
||||||
joints_idx = tuple(i for i in real_set if i not in set(self.gripper_idx))
|
|
||||||
|
|
||||||
joints_loss = self.mse(
|
joints_loss = self.mse(
|
||||||
pred[:, :, joints_idx],
|
pred[:, :, :REAL_DIMS],
|
||||||
target[:, :, joints_idx],
|
target[:, :, :REAL_DIMS],
|
||||||
) * self.JOINTS_SCALE
|
) * self.JOINTS_SCALE
|
||||||
|
|
||||||
# Separate losses for left and right arms for better monitoring
|
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||||
left_arm_loss = self.mse(
|
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
||||||
pred[:, :, self.LEFT_ARM_JOINTS],
|
|
||||||
target[:, :, self.LEFT_ARM_JOINTS],
|
# is gripper continuous? not bce?
|
||||||
)
|
gripper_loss = self.mse(
|
||||||
right_arm_loss = self.mse(
|
pred[:, :, [5, 11]],
|
||||||
pred[:, :, self.RIGHT_ARM_JOINTS],
|
target[:, :, [5, 11]],
|
||||||
target[:, :, self.RIGHT_ARM_JOINTS],
|
) * self.GRIPPER_SCALE
|
||||||
)
|
|
||||||
return {
|
return {
|
||||||
"joints_loss": joints_loss,
|
"joints_loss": joints_loss,
|
||||||
"gripper_loss": gripper_loss,
|
"gripper_loss": gripper_loss,
|
||||||
|
|||||||
Reference in New Issue
Block a user