mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
more
This commit is contained in:
@@ -52,6 +52,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||
the original scale.
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
@@ -202,7 +203,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
config.pop("type")
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
@@ -1,2 +1,6 @@
|
||||
# add domainid
|
||||
from lerobot.policies.xvla.processor_xvla import XVLAAddDomainIdProcessorStep, XVLAImageNetNormalizeProcessorStep, XVLAImageToFloatProcessorStep
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
XVLAAddDomainIdProcessorStep,
|
||||
XVLAImageNetNormalizeProcessorStep,
|
||||
XVLAImageToFloatProcessorStep,
|
||||
)
|
||||
|
||||
@@ -293,69 +293,6 @@ class FrankaJoint7ActionSpace(BaseActionSpace):
|
||||
return action
|
||||
|
||||
|
||||
@register_action("so101_bimanual_old")
|
||||
class BimanualSO101OldActionSpace(BaseActionSpace):
|
||||
"""
|
||||
Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
|
||||
|
||||
Layout: [left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
|
||||
- Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
||||
- Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
||||
Total: 12 dimensions
|
||||
"""
|
||||
|
||||
dim_action = 12
|
||||
gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
|
||||
GRIPPER_SCALE = 1.0
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
# Indices for left and right arm joints (excluding grippers)
|
||||
LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
|
||||
RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape, "pred/target shapes must match"
|
||||
batch_size, seq_len, action_dim = pred.shape
|
||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||
|
||||
# Gripper BCE loss (binary classification for open/close)
|
||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
||||
|
||||
# Joint positions MSE (all non-gripper dimensions)
|
||||
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
|
||||
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
|
||||
|
||||
# Separate losses for left and right arms for better monitoring
|
||||
left_arm_loss = self.mse(pred[:, :, self.LEFT_ARM_JOINTS], target[:, :, self.LEFT_ARM_JOINTS])
|
||||
right_arm_loss = self.mse(pred[:, :, self.RIGHT_ARM_JOINTS], target[:, :, self.RIGHT_ARM_JOINTS])
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
"left_arm_loss": left_arm_loss,
|
||||
"right_arm_loss": right_arm_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""Zero-out gripper channels in proprio/action to focus learning on continuous joint control."""
|
||||
proprio_m = proprio.clone()
|
||||
action_m = action.clone()
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply sigmoid to gripper logits to convert to [0, 1] range."""
|
||||
if action.size(-1) > max(self.gripper_idx):
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
return action
|
||||
|
||||
@register_action("so101_bimanual")
|
||||
class BimanualSO101ActionSpace(BaseActionSpace):
|
||||
"""
|
||||
@@ -377,7 +314,7 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
||||
REAL_DIM = 12
|
||||
|
||||
# Indices of real vs dummy channels
|
||||
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
|
||||
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
|
||||
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
|
||||
|
||||
# Grippers live in the real part
|
||||
@@ -412,7 +349,7 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
||||
|
||||
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Keep only the first REAL_DIM (12) dims for the real robot."""
|
||||
return x[..., :self.REAL_DIM]
|
||||
return x[..., : self.REAL_DIM]
|
||||
|
||||
# ---------- loss ----------
|
||||
|
||||
@@ -428,22 +365,28 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
||||
assert pred.shape == target.shape
|
||||
|
||||
# ---- MSE for all real dims (0–11) ----
|
||||
REAL_DIMS = 12
|
||||
real_dims = 12
|
||||
|
||||
joints_loss = self.mse(
|
||||
pred[:, :, :REAL_DIMS],
|
||||
target[:, :, :REAL_DIMS],
|
||||
) * self.JOINTS_SCALE
|
||||
joints_loss = (
|
||||
self.mse(
|
||||
pred[:, :, :real_dims],
|
||||
target[:, :, :real_dims],
|
||||
)
|
||||
* self.JOINTS_SCALE
|
||||
)
|
||||
|
||||
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
||||
|
||||
# is gripper continuous? not bce?
|
||||
gripper_loss = self.mse(
|
||||
pred[:, :, [5, 11]],
|
||||
target[:, :, [5, 11]],
|
||||
) * self.GRIPPER_SCALE
|
||||
|
||||
gripper_loss = (
|
||||
self.mse(
|
||||
pred[:, :, [5, 11]],
|
||||
target[:, :, [5, 11]],
|
||||
)
|
||||
* self.GRIPPER_SCALE
|
||||
)
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
|
||||
@@ -359,8 +359,9 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
- skip list for layers that should remain randomly initialized
|
||||
"""
|
||||
import safetensors.torch
|
||||
|
||||
# step 1: load config
|
||||
#TODO: jadechoghari, fix this
|
||||
# TODO: jadechoghari, fix this
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
@@ -373,7 +374,6 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user