diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 7271d7c9a..0ecfa169b 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -52,6 +52,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to the original scale. """ + n_obs_steps: int = 1 input_features: dict[str, PolicyFeature] = field(default_factory=dict) @@ -202,7 +203,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno with open(config_file) as f: config = json.load(f) - + config.pop("type") with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f: json.dump(config, f) diff --git a/src/lerobot/policies/xvla/__init__.py b/src/lerobot/policies/xvla/__init__.py index afc7e0142..356eb22e4 100644 --- a/src/lerobot/policies/xvla/__init__.py +++ b/src/lerobot/policies/xvla/__init__.py @@ -1,2 +1,6 @@ # add domainid -from lerobot.policies.xvla.processor_xvla import XVLAAddDomainIdProcessorStep, XVLAImageNetNormalizeProcessorStep, XVLAImageToFloatProcessorStep +from lerobot.policies.xvla.processor_xvla import ( + XVLAAddDomainIdProcessorStep, + XVLAImageNetNormalizeProcessorStep, + XVLAImageToFloatProcessorStep, +) diff --git a/src/lerobot/policies/xvla/action_hub.py b/src/lerobot/policies/xvla/action_hub.py index 0caf35708..80be6847d 100644 --- a/src/lerobot/policies/xvla/action_hub.py +++ b/src/lerobot/policies/xvla/action_hub.py @@ -293,69 +293,6 @@ class FrankaJoint7ActionSpace(BaseActionSpace): return action -@register_action("so101_bimanual_old") -class BimanualSO101OldActionSpace(BaseActionSpace): - """ - Bimanual SO101 robot: 2 arms with 5 joints each + gripper. - - Layout: [left_arm (5 joints + gripper), right_arm (5 joints + gripper)] - - Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper - - Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper - Total: 12 dimensions - """ - - dim_action = 12 - gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11 - GRIPPER_SCALE = 1.0 - JOINTS_SCALE = 1.0 - - # Indices for left and right arm joints (excluding grippers) - LEFT_ARM_JOINTS = (0, 1, 2, 3, 4) - RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10) - - def __init__(self): - super().__init__() - self.mse = nn.MSELoss() - self.bce = nn.BCEWithLogitsLoss() - - def compute_loss(self, pred, target): - assert pred.shape == target.shape, "pred/target shapes must match" - batch_size, seq_len, action_dim = pred.shape - _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx") - - # Gripper BCE loss (binary classification for open/close) - g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx] - gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE - - # Joint positions MSE (all non-gripper dimensions) - joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx)) - joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE - - # Separate losses for left and right arms for better monitoring - left_arm_loss = self.mse(pred[:, :, self.LEFT_ARM_JOINTS], target[:, :, self.LEFT_ARM_JOINTS]) - right_arm_loss = self.mse(pred[:, :, self.RIGHT_ARM_JOINTS], target[:, :, self.RIGHT_ARM_JOINTS]) - - return { - "joints_loss": joints_loss, - "gripper_loss": gripper_loss, - "left_arm_loss": left_arm_loss, - "right_arm_loss": right_arm_loss, - } - - def preprocess(self, proprio, action, mode="train"): - """Zero-out gripper channels in proprio/action to focus learning on continuous joint control.""" - proprio_m = proprio.clone() - action_m = action.clone() - proprio_m[..., self.gripper_idx] = 0.0 - action_m[..., self.gripper_idx] = 0.0 - return proprio_m, action_m - - def postprocess(self, action: torch.Tensor) -> torch.Tensor: - """Apply sigmoid to gripper logits to convert to [0, 1] range.""" - if action.size(-1) > max(self.gripper_idx): - action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx]) - return action - @register_action("so101_bimanual") class BimanualSO101ActionSpace(BaseActionSpace): """ @@ -377,7 +314,7 @@ class BimanualSO101ActionSpace(BaseActionSpace): REAL_DIM = 12 # Indices of real vs dummy channels - REAL_IDXS = tuple(range(REAL_DIM)) # 0..11 + REAL_IDXS = tuple(range(REAL_DIM)) # 0..11 DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19 # Grippers live in the real part @@ -412,7 +349,7 @@ class BimanualSO101ActionSpace(BaseActionSpace): def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor: """Keep only the first REAL_DIM (12) dims for the real robot.""" - return x[..., :self.REAL_DIM] + return x[..., : self.REAL_DIM] # ---------- loss ---------- @@ -428,22 +365,28 @@ class BimanualSO101ActionSpace(BaseActionSpace): assert pred.shape == target.shape # ---- MSE for all real dims (0–11) ---- - REAL_DIMS = 12 + real_dims = 12 - joints_loss = self.mse( - pred[:, :, :REAL_DIMS], - target[:, :, :REAL_DIMS], - ) * self.JOINTS_SCALE + joints_loss = ( + self.mse( + pred[:, :, :real_dims], + target[:, :, :real_dims], + ) + * self.JOINTS_SCALE + ) - left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6]) + left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6]) right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12]) # is gripper continuous? not bce? - gripper_loss = self.mse( - pred[:, :, [5, 11]], - target[:, :, [5, 11]], - ) * self.GRIPPER_SCALE - + gripper_loss = ( + self.mse( + pred[:, :, [5, 11]], + target[:, :, [5, 11]], + ) + * self.GRIPPER_SCALE + ) + return { "joints_loss": joints_loss, "gripper_loss": gripper_loss, diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index 8ae0cb1ca..444dc0552 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -359,8 +359,9 @@ class XVLAPolicy(PreTrainedPolicy): - skip list for layers that should remain randomly initialized """ import safetensors.torch + # step 1: load config - #TODO: jadechoghari, fix this + # TODO: jadechoghari, fix this if config is None: config = PreTrainedConfig.from_pretrained( pretrained_name_or_path=pretrained_name_or_path, @@ -373,7 +374,6 @@ class XVLAPolicy(PreTrainedPolicy): revision=revision, **kwargs, ) - model_id = str(pretrained_name_or_path) instance = cls(config, **kwargs)