mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
add franka action
This commit is contained in:
@@ -39,8 +39,8 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
|||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.policies.utils import validate_visual_features_consistency
|
from lerobot.policies.utils import validate_visual_features_consistency
|
||||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
|
||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
batch_to_transition,
|
batch_to_transition,
|
||||||
|
|||||||
@@ -15,18 +15,21 @@
|
|||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Iterable, Tuple, Dict, Type
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Registry
|
# Registry
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}
|
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_action(name: str):
|
def register_action(name: str):
|
||||||
"""Decorator for registering a new action space."""
|
"""Decorator for registering a new action space."""
|
||||||
|
|
||||||
def _wrap(cls):
|
def _wrap(cls):
|
||||||
key = name.lower()
|
key = name.lower()
|
||||||
if key in ACTION_REGISTRY:
|
if key in ACTION_REGISTRY:
|
||||||
@@ -34,10 +37,11 @@ def register_action(name: str):
|
|||||||
ACTION_REGISTRY[key] = cls
|
ACTION_REGISTRY[key] = cls
|
||||||
cls.name = key
|
cls.name = key
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return _wrap
|
return _wrap
|
||||||
|
|
||||||
|
|
||||||
def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
|
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
|
||||||
"""Instantiate a registered action space by name."""
|
"""Instantiate a registered action space by name."""
|
||||||
key = name.lower()
|
key = name.lower()
|
||||||
if key not in ACTION_REGISTRY:
|
if key not in ACTION_REGISTRY:
|
||||||
@@ -62,7 +66,7 @@ class BaseActionSpace(nn.Module):
|
|||||||
|
|
||||||
name: str = "base"
|
name: str = "base"
|
||||||
dim_action: int = 0
|
dim_action: int = 0
|
||||||
gripper_idx: Tuple[int, ...] = ()
|
gripper_idx: tuple[int, ...] = ()
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -70,10 +74,10 @@ class BaseActionSpace(nn.Module):
|
|||||||
# ---------------------------------------------------------------------
|
# ---------------------------------------------------------------------
|
||||||
# Core supervised loss
|
# Core supervised loss
|
||||||
# ---------------------------------------------------------------------
|
# ---------------------------------------------------------------------
|
||||||
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
|
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
|
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||||
"""Alias for compute_loss."""
|
"""Alias for compute_loss."""
|
||||||
return self.compute_loss(pred, target)
|
return self.compute_loss(pred, target)
|
||||||
|
|
||||||
@@ -85,7 +89,7 @@ class BaseActionSpace(nn.Module):
|
|||||||
proprio: torch.Tensor,
|
proprio: torch.Tensor,
|
||||||
action: torch.Tensor,
|
action: torch.Tensor,
|
||||||
mode: str = "train",
|
mode: str = "train",
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Default: return unchanged."""
|
"""Default: return unchanged."""
|
||||||
return proprio, action
|
return proprio, action
|
||||||
|
|
||||||
@@ -137,14 +141,14 @@ class EE6DActionSpace(BaseActionSpace):
|
|||||||
|
|
||||||
# XYZ position
|
# XYZ position
|
||||||
pos_loss = (
|
pos_loss = (
|
||||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
|
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
|
||||||
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||||
) * self.XYZ_SCALE
|
) * self.XYZ_SCALE
|
||||||
|
|
||||||
# Rotation 6D
|
# Rotation 6D
|
||||||
rot_loss = (
|
rot_loss = (
|
||||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
|
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
|
||||||
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||||
) * self.ROT_SCALE
|
) * self.ROT_SCALE
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -236,14 +240,16 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
|
|||||||
B, T, D = pred.shape
|
B, T, D = pred.shape
|
||||||
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
|
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
|
||||||
|
|
||||||
gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
|
gripper_loss = (
|
||||||
|
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
|
||||||
|
)
|
||||||
pos_loss = (
|
pos_loss = (
|
||||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
|
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
|
||||||
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||||
) * self.XYZ_SCALE
|
) * self.XYZ_SCALE
|
||||||
rot_loss = (
|
rot_loss = (
|
||||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
|
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
|
||||||
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||||
) * self.ROT_SCALE
|
) * self.ROT_SCALE
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -261,6 +267,32 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
|
|||||||
return action
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("franka_joint7")
|
||||||
|
class FrankaJoint7ActionSpace(BaseActionSpace):
|
||||||
|
"""Franka Panda joint-space: 7 joints, no gripper."""
|
||||||
|
|
||||||
|
dim_action = 7
|
||||||
|
JOINTS_SCALE = 1.0
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
|
||||||
|
def compute_loss(self, pred, target):
|
||||||
|
assert pred.shape == target.shape, "pred/target shapes must match"
|
||||||
|
B, T, D = pred.shape
|
||||||
|
joints_loss = self.mse(pred, target) * self.JOINTS_SCALE
|
||||||
|
return {"joints_loss": joints_loss}
|
||||||
|
|
||||||
|
def preprocess(self, proprio, action, mode="train"):
|
||||||
|
"""No preprocessing needed for 7 joint actions."""
|
||||||
|
return proprio, action
|
||||||
|
|
||||||
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Return directly (no sigmoid since no gripper)."""
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Exports
|
# Exports
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -271,5 +303,6 @@ __all__ = [
|
|||||||
"EE6DActionSpace",
|
"EE6DActionSpace",
|
||||||
"JointActionSpace",
|
"JointActionSpace",
|
||||||
"AGIBOTEE6DActionSpace",
|
"AGIBOTEE6DActionSpace",
|
||||||
|
"FrankaJoint7ActionSpace",
|
||||||
"ACTION_REGISTRY",
|
"ACTION_REGISTRY",
|
||||||
]
|
]
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -12,16 +11,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
""" Florence-2 configuration"""
|
""" Florence-2 configuration"""
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import AutoConfig
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Florence2VisionConfig(PretrainedConfig):
|
class Florence2VisionConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
|
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
|
||||||
@@ -118,7 +117,6 @@ class Florence2VisionConfig(PretrainedConfig):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Florence2LanguageConfig(PretrainedConfig):
|
class Florence2LanguageConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
|
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
|
||||||
@@ -269,6 +267,7 @@ class Florence2LanguageConfig(PretrainedConfig):
|
|||||||
"The config can simply be saved and uploaded again to be fixed."
|
"The config can simply be saved and uploaded again to be fixed."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Florence2Config(PretrainedConfig):
|
class Florence2Config(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
|
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
|
||||||
@@ -335,5 +334,4 @@ class Florence2Config(PretrainedConfig):
|
|||||||
if text_config is not None:
|
if text_config is not None:
|
||||||
self.text_config = Florence2LanguageConfig(**text_config)
|
self.text_config = Florence2LanguageConfig(**text_config)
|
||||||
|
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|||||||
@@ -124,43 +124,37 @@ class XVLAConfig(PreTrainedConfig):
|
|||||||
# TODO: jadechoghari: provide default way, and do not hardcode
|
# TODO: jadechoghari: provide default way, and do not hardcode
|
||||||
# Ensure vision_config and text_config are provided with defaults if not specified
|
# Ensure vision_config and text_config are provided with defaults if not specified
|
||||||
config_dict = dict(self.florence_config)
|
config_dict = dict(self.florence_config)
|
||||||
if 'vision_config' not in config_dict or config_dict['vision_config'] is None:
|
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
|
||||||
# Provide default vision config
|
# Provide default vision config
|
||||||
config_dict['vision_config'] = {
|
config_dict["vision_config"] = {
|
||||||
'model_type': 'davit',
|
"model_type": "davit",
|
||||||
'drop_path_rate': 0.1,
|
"drop_path_rate": 0.1,
|
||||||
'patch_size': [14, 7, 7, 7],
|
"patch_size": [14, 7, 7, 7],
|
||||||
'patch_stride': [4, 2, 2, 2],
|
"patch_stride": [4, 2, 2, 2],
|
||||||
'patch_padding': [3, 1, 1, 1],
|
"patch_padding": [3, 1, 1, 1],
|
||||||
'patch_prenorm': [False, True, True, True],
|
"patch_prenorm": [False, True, True, True],
|
||||||
'dim_embed': [256, 512, 1024, 2048],
|
"dim_embed": [256, 512, 1024, 2048],
|
||||||
'num_heads': [8, 16, 32, 64],
|
"num_heads": [8, 16, 32, 64],
|
||||||
'num_groups': [8, 16, 32, 64],
|
"num_groups": [8, 16, 32, 64],
|
||||||
'depths': [1, 1, 9, 1],
|
"depths": [1, 1, 9, 1],
|
||||||
'window_size': 12,
|
"window_size": 12,
|
||||||
'projection_dim': 1024,
|
"projection_dim": 1024,
|
||||||
'visual_temporal_embedding': {
|
"visual_temporal_embedding": {"type": "COSINE", "max_temporal_embeddings": 100},
|
||||||
'type': 'COSINE',
|
"image_pos_embed": {"type": "learned_abs_2d", "max_pos_embeddings": 50},
|
||||||
'max_temporal_embeddings': 100
|
"image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"],
|
||||||
},
|
|
||||||
'image_pos_embed': {
|
|
||||||
'type': 'learned_abs_2d',
|
|
||||||
'max_pos_embeddings': 50
|
|
||||||
},
|
|
||||||
'image_feature_source': ['spatial_avg_pool', 'temporal_avg_pool']
|
|
||||||
}
|
}
|
||||||
if 'text_config' not in config_dict or config_dict['text_config'] is None:
|
if "text_config" not in config_dict or config_dict["text_config"] is None:
|
||||||
# Provide default text config
|
# Provide default text config
|
||||||
config_dict['text_config'] = {
|
config_dict["text_config"] = {
|
||||||
'model_type': 'florence2_language',
|
"model_type": "florence2_language",
|
||||||
'vocab_size': 51289,
|
"vocab_size": 51289,
|
||||||
'd_model': 1024,
|
"d_model": 1024,
|
||||||
'encoder_layers': 12,
|
"encoder_layers": 12,
|
||||||
'decoder_layers': 12,
|
"decoder_layers": 12,
|
||||||
'encoder_attention_heads': 16,
|
"encoder_attention_heads": 16,
|
||||||
'decoder_attention_heads': 16,
|
"decoder_attention_heads": 16,
|
||||||
'encoder_ffn_dim': 4096,
|
"encoder_ffn_dim": 4096,
|
||||||
'decoder_ffn_dim': 4096,
|
"decoder_ffn_dim": 4096,
|
||||||
}
|
}
|
||||||
self._florence_config_obj = Florence2Config(**config_dict)
|
self._florence_config_obj = Florence2Config(**config_dict)
|
||||||
return self._florence_config_obj
|
return self._florence_config_obj
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
@@ -87,7 +86,7 @@ class XVLAModel(nn.Module):
|
|||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
image_mask: torch.Tensor,
|
image_mask: torch.Tensor,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Encode text and multi-view images via Florence2 encoder.
|
Encode text and multi-view images via Florence2 encoder.
|
||||||
"""
|
"""
|
||||||
@@ -129,13 +128,14 @@ class XVLAModel(nn.Module):
|
|||||||
domain_id: torch.LongTensor,
|
domain_id: torch.LongTensor,
|
||||||
proprio: torch.Tensor,
|
proprio: torch.Tensor,
|
||||||
action: torch.Tensor,
|
action: torch.Tensor,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||||
|
|
||||||
batch_size = input_ids.shape[0]
|
batch_size = input_ids.shape[0]
|
||||||
t = (torch.rand(1, device=input_ids.device) + torch.arange(batch_size, device=input_ids.device) / batch_size) % (
|
t = (
|
||||||
1 - 1e-5
|
torch.rand(1, device=input_ids.device)
|
||||||
)
|
+ torch.arange(batch_size, device=input_ids.device) / batch_size
|
||||||
|
) % (1 - 1e-5)
|
||||||
|
|
||||||
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||||
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
|
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
|
||||||
@@ -350,7 +350,9 @@ def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float
|
|||||||
ratio = max(current_width / width, current_height / height)
|
ratio = max(current_width / width, current_height / height)
|
||||||
resized_height = int(current_height / ratio)
|
resized_height = int(current_height / ratio)
|
||||||
resized_width = int(current_width / ratio)
|
resized_width = int(current_width / ratio)
|
||||||
resized_img = F.interpolate(img, size=(resized_height, resized_width), mode="bilinear", align_corners=False)
|
resized_img = F.interpolate(
|
||||||
|
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
pad_height = max(0, height - resized_height)
|
pad_height = max(0, height - resized_height)
|
||||||
pad_width = max(0, width - resized_width)
|
pad_width = max(0, width - resized_width)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
@@ -88,7 +88,7 @@ class XVLAProcessor(ProcessorMixin):
|
|||||||
super().__init__(image_processor, tokenizer)
|
super().__init__(image_processor, tokenizer)
|
||||||
|
|
||||||
# ================== LANGUAGE ENCODING ==================
|
# ================== LANGUAGE ENCODING ==================
|
||||||
def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
|
def encode_language(self, language_instruction: str | list[str]) -> dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Tokenize one or more language instructions.
|
Tokenize one or more language instructions.
|
||||||
|
|
||||||
@@ -117,11 +117,7 @@ class XVLAProcessor(ProcessorMixin):
|
|||||||
return {"input_ids": inputs["input_ids"]}
|
return {"input_ids": inputs["input_ids"]}
|
||||||
|
|
||||||
# ================== IMAGE ENCODING ==================
|
# ================== IMAGE ENCODING ==================
|
||||||
def encode_image(
|
def encode_image(self, images: list | list[list], **kwargs) -> dict[str, torch.Tensor]:
|
||||||
self,
|
|
||||||
images: Union[List, List[List]],
|
|
||||||
**kwargs
|
|
||||||
) -> Dict[str, torch.Tensor]:
|
|
||||||
"""
|
"""
|
||||||
Preprocess one or more sets of multi-view images.
|
Preprocess one or more sets of multi-view images.
|
||||||
|
|
||||||
@@ -157,8 +153,7 @@ class XVLAProcessor(ProcessorMixin):
|
|||||||
# Pad to self.num_views
|
# Pad to self.num_views
|
||||||
if V_exist < self.num_views:
|
if V_exist < self.num_views:
|
||||||
processed = torch.cat(
|
processed = torch.cat(
|
||||||
[processed,
|
[processed, processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
|
||||||
processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -177,10 +172,10 @@ class XVLAProcessor(ProcessorMixin):
|
|||||||
# ================== COMBINED CALL ==================
|
# ================== COMBINED CALL ==================
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
images: Optional[Union[List, List[List]]] = None,
|
images: list | list[list] | None = None,
|
||||||
language_instruction: Optional[Union[str, List[str]]] = None,
|
language_instruction: str | list[str] | None = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Combine image and text encoding into a unified multimodal input.
|
Combine image and text encoding into a unified multimodal input.
|
||||||
|
|
||||||
@@ -202,7 +197,7 @@ class XVLAProcessor(ProcessorMixin):
|
|||||||
"image_mask": [B, num_views], optional
|
"image_mask": [B, num_views], optional
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
outputs: Dict[str, Any] = {}
|
outputs: dict[str, Any] = {}
|
||||||
|
|
||||||
# Encode language if provided
|
# Encode language if provided
|
||||||
if language_instruction is not None:
|
if language_instruction is not None:
|
||||||
@@ -243,7 +238,9 @@ def make_xvla_pre_post_processors(
|
|||||||
padding_side=config.tokenizer_padding_side,
|
padding_side=config.tokenizer_padding_side,
|
||||||
),
|
),
|
||||||
DeviceProcessorStep(device=config.device),
|
DeviceProcessorStep(device=config.device),
|
||||||
NormalizerProcessorStep(features=features, norm_map=config.normalization_mapping, stats=dataset_stats),
|
NormalizerProcessorStep(
|
||||||
|
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||||
|
),
|
||||||
]
|
]
|
||||||
output_steps = [
|
output_steps = [
|
||||||
UnnormalizerProcessorStep(
|
UnnormalizerProcessorStep(
|
||||||
|
|||||||
@@ -17,17 +17,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from collections.abc import Iterable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Final, Iterable, Tuple
|
from typing import Final
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------- Small utils ----------------------------------
|
# ------------------------------- Small utils ----------------------------------
|
||||||
|
|
||||||
def _to_2tuple(x) -> Tuple:
|
|
||||||
|
def _to_2tuple(x) -> tuple:
|
||||||
"""Minimal replacement for timm.layers.to_2tuple."""
|
"""Minimal replacement for timm.layers.to_2tuple."""
|
||||||
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
|
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
|
||||||
t = tuple(x)
|
t = tuple(x)
|
||||||
@@ -42,6 +43,7 @@ def _has_sdp_attention() -> bool:
|
|||||||
|
|
||||||
# ---------------------------------- MLP --------------------------------------
|
# ---------------------------------- MLP --------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
class Mlp(nn.Module):
|
||||||
"""
|
"""
|
||||||
MLP used in ViT-style blocks.
|
MLP used in ViT-style blocks.
|
||||||
@@ -55,8 +57,8 @@ class Mlp(nn.Module):
|
|||||||
hidden_features: int | None = None,
|
hidden_features: int | None = None,
|
||||||
out_features: int | None = None,
|
out_features: int | None = None,
|
||||||
norm_layer: type[nn.Module] | None = None,
|
norm_layer: type[nn.Module] | None = None,
|
||||||
bias: bool | Tuple[bool, bool] = True,
|
bias: bool | tuple[bool, bool] = True,
|
||||||
drop: float | Tuple[float, float] = 0.0,
|
drop: float | tuple[float, float] = 0.0,
|
||||||
use_conv: bool = False,
|
use_conv: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -86,6 +88,7 @@ class Mlp(nn.Module):
|
|||||||
|
|
||||||
# -------------------------------- Attention ----------------------------------
|
# -------------------------------- Attention ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
"""
|
"""
|
||||||
Multi-Head Self-Attention with optional fused SDPA fallback.
|
Multi-Head Self-Attention with optional fused SDPA fallback.
|
||||||
@@ -143,7 +146,9 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
if self.fused_attn:
|
if self.fused_attn:
|
||||||
x = F.scaled_dot_product_attention(
|
x = F.scaled_dot_product_attention(
|
||||||
q, k, v,
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||||
) # [B, H, T, Dh]
|
) # [B, H, T, Dh]
|
||||||
else:
|
else:
|
||||||
@@ -161,6 +166,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# ------------------------------- Utilities -----------------------------------
|
# ------------------------------- Utilities -----------------------------------
|
||||||
|
|
||||||
|
|
||||||
def basic_init(module: nn.Module) -> None:
|
def basic_init(module: nn.Module) -> None:
|
||||||
"""
|
"""
|
||||||
Apply a basic initialization scheme to Linear layers.
|
Apply a basic initialization scheme to Linear layers.
|
||||||
@@ -194,9 +200,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc
|
|||||||
"""
|
"""
|
||||||
half = dim // 2
|
half = dim // 2
|
||||||
freqs = torch.exp(
|
freqs = torch.exp(
|
||||||
-math.log(max_period)
|
-math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
|
||||||
* torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
|
|
||||||
/ half
|
|
||||||
)
|
)
|
||||||
args = t[:, None] * freqs[None]
|
args = t[:, None] * freqs[None]
|
||||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
@@ -207,6 +211,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc
|
|||||||
|
|
||||||
# ------------------------------- Core Layers ----------------------------------
|
# ------------------------------- Core Layers ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
class DomainAwareLinear(nn.Module):
|
class DomainAwareLinear(nn.Module):
|
||||||
"""
|
"""
|
||||||
Linear layer with domain-conditioned parameters (per-sample).
|
Linear layer with domain-conditioned parameters (per-sample).
|
||||||
@@ -283,6 +288,7 @@ class TransformerBlock(nn.Module):
|
|||||||
|
|
||||||
# --------------------------- Main Model ---------------------------------------
|
# --------------------------- Main Model ---------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class SoftPromptedTransformer(nn.Module):
|
class SoftPromptedTransformer(nn.Module):
|
||||||
"""
|
"""
|
||||||
Multi-modal, domain-aware Transformer with optional soft prompts.
|
Multi-modal, domain-aware Transformer with optional soft prompts.
|
||||||
@@ -318,7 +324,9 @@ class SoftPromptedTransformer(nn.Module):
|
|||||||
|
|
||||||
if use_hetero_proj:
|
if use_hetero_proj:
|
||||||
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
|
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
|
||||||
self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
|
self.aux_visual_proj = DomainAwareLinear(
|
||||||
|
multi_modal_input_size, hidden_size, num_domains=num_domains
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||||
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||||
@@ -376,7 +384,11 @@ class SoftPromptedTransformer(nn.Module):
|
|||||||
# Project visual streams and concatenate
|
# Project visual streams and concatenate
|
||||||
if self.use_hetero_proj:
|
if self.use_hetero_proj:
|
||||||
x = torch.cat(
|
x = torch.cat(
|
||||||
[x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
|
[
|
||||||
|
x,
|
||||||
|
self.vlm_proj(vlm_features, domain_id),
|
||||||
|
self.aux_visual_proj(aux_visual_inputs, domain_id),
|
||||||
|
],
|
||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -385,9 +397,7 @@ class SoftPromptedTransformer(nn.Module):
|
|||||||
# Add positional embeddings (truncate if needed)
|
# Add positional embeddings (truncate if needed)
|
||||||
seq_len = x.shape[1]
|
seq_len = x.shape[1]
|
||||||
if seq_len > self.pos_emb.shape[1]:
|
if seq_len > self.pos_emb.shape[1]:
|
||||||
raise ValueError(
|
raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
|
||||||
f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
|
|
||||||
)
|
|
||||||
x = x + self.pos_emb[:, :seq_len, :]
|
x = x + self.pos_emb[:, :seq_len, :]
|
||||||
|
|
||||||
# Append soft prompts
|
# Append soft prompts
|
||||||
|
|||||||
+1
-2
@@ -1,6 +1,5 @@
|
|||||||
from lerobot.policies.factory import make_policy
|
|
||||||
from lerobot.policies.factory import make_policy_config
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
|
from lerobot.policies.factory import make_policy, make_policy_config
|
||||||
|
|
||||||
cfg = make_policy_config("xvla")
|
cfg = make_policy_config("xvla")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ lerobot-train \
|
|||||||
--output_dir=outputs/train/act_your_dataset \
|
--output_dir=outputs/train/act_your_dataset \
|
||||||
--job_name=xvla_so101_pickplace \
|
--job_name=xvla_so101_pickplace \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
|
--policy.action_mode=franka_joint7 \
|
||||||
--wandb.enable=true \
|
--wandb.enable=true \
|
||||||
--policy.repo_id=jadechoghari/xvla_policy \
|
--policy.repo_id=jadechoghari/xvla_policy \
|
||||||
--steps=10000
|
--steps=10000
|
||||||
|
|||||||
Reference in New Issue
Block a user