This commit is contained in:
Jade Choghari
2025-11-24 10:44:00 +01:00
parent 5052d4d70b
commit 8f2321af27
4 changed files with 28 additions and 80 deletions
+2 -1
View File
@@ -52,6 +52,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
"""
n_obs_steps: int = 1
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
@@ -202,7 +203,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
with open(config_file) as f:
config = json.load(f)
config.pop("type")
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
json.dump(config, f)
+5 -1
View File
@@ -1,2 +1,6 @@
# add domainid
from lerobot.policies.xvla.processor_xvla import XVLAAddDomainIdProcessorStep, XVLAImageNetNormalizeProcessorStep, XVLAImageToFloatProcessorStep
from lerobot.policies.xvla.processor_xvla import (
XVLAAddDomainIdProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAImageToFloatProcessorStep,
)
+19 -76
View File
@@ -293,69 +293,6 @@ class FrankaJoint7ActionSpace(BaseActionSpace):
return action
@register_action("so101_bimanual_old")
class BimanualSO101OldActionSpace(BaseActionSpace):
"""
Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
Layout: [left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
- Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
- Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
Total: 12 dimensions
"""
dim_action = 12
gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
GRIPPER_SCALE = 1.0
JOINTS_SCALE = 1.0
# Indices for left and right arm joints (excluding grippers)
LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
# Gripper BCE loss (binary classification for open/close)
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
# Joint positions MSE (all non-gripper dimensions)
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
# Separate losses for left and right arms for better monitoring
left_arm_loss = self.mse(pred[:, :, self.LEFT_ARM_JOINTS], target[:, :, self.LEFT_ARM_JOINTS])
right_arm_loss = self.mse(pred[:, :, self.RIGHT_ARM_JOINTS], target[:, :, self.RIGHT_ARM_JOINTS])
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
"left_arm_loss": left_arm_loss,
"right_arm_loss": right_arm_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action to focus learning on continuous joint control."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits to convert to [0, 1] range."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
@register_action("so101_bimanual")
class BimanualSO101ActionSpace(BaseActionSpace):
"""
@@ -377,7 +314,7 @@ class BimanualSO101ActionSpace(BaseActionSpace):
REAL_DIM = 12
# Indices of real vs dummy channels
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
# Grippers live in the real part
@@ -412,7 +349,7 @@ class BimanualSO101ActionSpace(BaseActionSpace):
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Keep only the first REAL_DIM (12) dims for the real robot."""
return x[..., :self.REAL_DIM]
return x[..., : self.REAL_DIM]
# ---------- loss ----------
@@ -428,22 +365,28 @@ class BimanualSO101ActionSpace(BaseActionSpace):
assert pred.shape == target.shape
# ---- MSE for all real dims (011) ----
REAL_DIMS = 12
real_dims = 12
joints_loss = self.mse(
pred[:, :, :REAL_DIMS],
target[:, :, :REAL_DIMS],
) * self.JOINTS_SCALE
joints_loss = (
self.mse(
pred[:, :, :real_dims],
target[:, :, :real_dims],
)
* self.JOINTS_SCALE
)
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
# is gripper continuous? not bce?
gripper_loss = self.mse(
pred[:, :, [5, 11]],
target[:, :, [5, 11]],
) * self.GRIPPER_SCALE
gripper_loss = (
self.mse(
pred[:, :, [5, 11]],
target[:, :, [5, 11]],
)
* self.GRIPPER_SCALE
)
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
+2 -2
View File
@@ -359,8 +359,9 @@ class XVLAPolicy(PreTrainedPolicy):
- skip list for layers that should remain randomly initialized
"""
import safetensors.torch
# step 1: load config
#TODO: jadechoghari, fix this
# TODO: jadechoghari, fix this
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
@@ -373,7 +374,6 @@ class XVLAPolicy(PreTrainedPolicy):
revision=revision,
**kwargs,
)
model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs)