mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
more
This commit is contained in:
@@ -52,6 +52,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
|||||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||||
the original scale.
|
the original scale.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
n_obs_steps: int = 1
|
n_obs_steps: int = 1
|
||||||
|
|
||||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
@@ -202,7 +203,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
|||||||
|
|
||||||
with open(config_file) as f:
|
with open(config_file) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
config.pop("type")
|
config.pop("type")
|
||||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
|
|||||||
@@ -1,2 +1,6 @@
|
|||||||
# add domainid
|
# add domainid
|
||||||
from lerobot.policies.xvla.processor_xvla import XVLAAddDomainIdProcessorStep, XVLAImageNetNormalizeProcessorStep, XVLAImageToFloatProcessorStep
|
from lerobot.policies.xvla.processor_xvla import (
|
||||||
|
XVLAAddDomainIdProcessorStep,
|
||||||
|
XVLAImageNetNormalizeProcessorStep,
|
||||||
|
XVLAImageToFloatProcessorStep,
|
||||||
|
)
|
||||||
|
|||||||
@@ -293,69 +293,6 @@ class FrankaJoint7ActionSpace(BaseActionSpace):
|
|||||||
return action
|
return action
|
||||||
|
|
||||||
|
|
||||||
@register_action("so101_bimanual_old")
|
|
||||||
class BimanualSO101OldActionSpace(BaseActionSpace):
|
|
||||||
"""
|
|
||||||
Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
|
|
||||||
|
|
||||||
Layout: [left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
|
|
||||||
- Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
|
||||||
- Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
|
||||||
Total: 12 dimensions
|
|
||||||
"""
|
|
||||||
|
|
||||||
dim_action = 12
|
|
||||||
gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
|
|
||||||
GRIPPER_SCALE = 1.0
|
|
||||||
JOINTS_SCALE = 1.0
|
|
||||||
|
|
||||||
# Indices for left and right arm joints (excluding grippers)
|
|
||||||
LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
|
|
||||||
RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.mse = nn.MSELoss()
|
|
||||||
self.bce = nn.BCEWithLogitsLoss()
|
|
||||||
|
|
||||||
def compute_loss(self, pred, target):
|
|
||||||
assert pred.shape == target.shape, "pred/target shapes must match"
|
|
||||||
batch_size, seq_len, action_dim = pred.shape
|
|
||||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
|
||||||
|
|
||||||
# Gripper BCE loss (binary classification for open/close)
|
|
||||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
|
||||||
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
|
||||||
|
|
||||||
# Joint positions MSE (all non-gripper dimensions)
|
|
||||||
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
|
|
||||||
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
|
|
||||||
|
|
||||||
# Separate losses for left and right arms for better monitoring
|
|
||||||
left_arm_loss = self.mse(pred[:, :, self.LEFT_ARM_JOINTS], target[:, :, self.LEFT_ARM_JOINTS])
|
|
||||||
right_arm_loss = self.mse(pred[:, :, self.RIGHT_ARM_JOINTS], target[:, :, self.RIGHT_ARM_JOINTS])
|
|
||||||
|
|
||||||
return {
|
|
||||||
"joints_loss": joints_loss,
|
|
||||||
"gripper_loss": gripper_loss,
|
|
||||||
"left_arm_loss": left_arm_loss,
|
|
||||||
"right_arm_loss": right_arm_loss,
|
|
||||||
}
|
|
||||||
|
|
||||||
def preprocess(self, proprio, action, mode="train"):
|
|
||||||
"""Zero-out gripper channels in proprio/action to focus learning on continuous joint control."""
|
|
||||||
proprio_m = proprio.clone()
|
|
||||||
action_m = action.clone()
|
|
||||||
proprio_m[..., self.gripper_idx] = 0.0
|
|
||||||
action_m[..., self.gripper_idx] = 0.0
|
|
||||||
return proprio_m, action_m
|
|
||||||
|
|
||||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Apply sigmoid to gripper logits to convert to [0, 1] range."""
|
|
||||||
if action.size(-1) > max(self.gripper_idx):
|
|
||||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
|
||||||
return action
|
|
||||||
|
|
||||||
@register_action("so101_bimanual")
|
@register_action("so101_bimanual")
|
||||||
class BimanualSO101ActionSpace(BaseActionSpace):
|
class BimanualSO101ActionSpace(BaseActionSpace):
|
||||||
"""
|
"""
|
||||||
@@ -377,7 +314,7 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
|||||||
REAL_DIM = 12
|
REAL_DIM = 12
|
||||||
|
|
||||||
# Indices of real vs dummy channels
|
# Indices of real vs dummy channels
|
||||||
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
|
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
|
||||||
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
|
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
|
||||||
|
|
||||||
# Grippers live in the real part
|
# Grippers live in the real part
|
||||||
@@ -412,7 +349,7 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
|||||||
|
|
||||||
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Keep only the first REAL_DIM (12) dims for the real robot."""
|
"""Keep only the first REAL_DIM (12) dims for the real robot."""
|
||||||
return x[..., :self.REAL_DIM]
|
return x[..., : self.REAL_DIM]
|
||||||
|
|
||||||
# ---------- loss ----------
|
# ---------- loss ----------
|
||||||
|
|
||||||
@@ -428,22 +365,28 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
|||||||
assert pred.shape == target.shape
|
assert pred.shape == target.shape
|
||||||
|
|
||||||
# ---- MSE for all real dims (0–11) ----
|
# ---- MSE for all real dims (0–11) ----
|
||||||
REAL_DIMS = 12
|
real_dims = 12
|
||||||
|
|
||||||
joints_loss = self.mse(
|
joints_loss = (
|
||||||
pred[:, :, :REAL_DIMS],
|
self.mse(
|
||||||
target[:, :, :REAL_DIMS],
|
pred[:, :, :real_dims],
|
||||||
) * self.JOINTS_SCALE
|
target[:, :, :real_dims],
|
||||||
|
)
|
||||||
|
* self.JOINTS_SCALE
|
||||||
|
)
|
||||||
|
|
||||||
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||||
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
||||||
|
|
||||||
# is gripper continuous? not bce?
|
# is gripper continuous? not bce?
|
||||||
gripper_loss = self.mse(
|
gripper_loss = (
|
||||||
pred[:, :, [5, 11]],
|
self.mse(
|
||||||
target[:, :, [5, 11]],
|
pred[:, :, [5, 11]],
|
||||||
) * self.GRIPPER_SCALE
|
target[:, :, [5, 11]],
|
||||||
|
)
|
||||||
|
* self.GRIPPER_SCALE
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"joints_loss": joints_loss,
|
"joints_loss": joints_loss,
|
||||||
"gripper_loss": gripper_loss,
|
"gripper_loss": gripper_loss,
|
||||||
|
|||||||
@@ -359,8 +359,9 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
- skip list for layers that should remain randomly initialized
|
- skip list for layers that should remain randomly initialized
|
||||||
"""
|
"""
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
# step 1: load config
|
# step 1: load config
|
||||||
#TODO: jadechoghari, fix this
|
# TODO: jadechoghari, fix this
|
||||||
if config is None:
|
if config is None:
|
||||||
config = PreTrainedConfig.from_pretrained(
|
config = PreTrainedConfig.from_pretrained(
|
||||||
pretrained_name_or_path=pretrained_name_or_path,
|
pretrained_name_or_path=pretrained_name_or_path,
|
||||||
@@ -373,7 +374,6 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
model_id = str(pretrained_name_or_path)
|
model_id = str(pretrained_name_or_path)
|
||||||
instance = cls(config, **kwargs)
|
instance = cls(config, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user