mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
add franka action
This commit is contained in:
@@ -39,8 +39,8 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import validate_visual_features_consistency
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
batch_to_transition,
|
||||
|
||||
@@ -15,18 +15,21 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Iterable, Tuple, Dict, Type
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# =============================================================================
|
||||
# Registry
|
||||
# =============================================================================
|
||||
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}
|
||||
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
|
||||
|
||||
|
||||
def register_action(name: str):
|
||||
"""Decorator for registering a new action space."""
|
||||
|
||||
def _wrap(cls):
|
||||
key = name.lower()
|
||||
if key in ACTION_REGISTRY:
|
||||
@@ -34,10 +37,11 @@ def register_action(name: str):
|
||||
ACTION_REGISTRY[key] = cls
|
||||
cls.name = key
|
||||
return cls
|
||||
|
||||
return _wrap
|
||||
|
||||
|
||||
def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
|
||||
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
|
||||
"""Instantiate a registered action space by name."""
|
||||
key = name.lower()
|
||||
if key not in ACTION_REGISTRY:
|
||||
@@ -62,7 +66,7 @@ class BaseActionSpace(nn.Module):
|
||||
|
||||
name: str = "base"
|
||||
dim_action: int = 0
|
||||
gripper_idx: Tuple[int, ...] = ()
|
||||
gripper_idx: tuple[int, ...] = ()
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -70,10 +74,10 @@ class BaseActionSpace(nn.Module):
|
||||
# ---------------------------------------------------------------------
|
||||
# Core supervised loss
|
||||
# ---------------------------------------------------------------------
|
||||
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||
"""Alias for compute_loss."""
|
||||
return self.compute_loss(pred, target)
|
||||
|
||||
@@ -85,7 +89,7 @@ class BaseActionSpace(nn.Module):
|
||||
proprio: torch.Tensor,
|
||||
action: torch.Tensor,
|
||||
mode: str = "train",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Default: return unchanged."""
|
||||
return proprio, action
|
||||
|
||||
@@ -137,14 +141,14 @@ class EE6DActionSpace(BaseActionSpace):
|
||||
|
||||
# XYZ position
|
||||
pos_loss = (
|
||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
|
||||
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
|
||||
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||
) * self.XYZ_SCALE
|
||||
|
||||
# Rotation 6D
|
||||
rot_loss = (
|
||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
|
||||
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
|
||||
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||
) * self.ROT_SCALE
|
||||
|
||||
return {
|
||||
@@ -236,14 +240,16 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
|
||||
B, T, D = pred.shape
|
||||
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
|
||||
|
||||
gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
|
||||
gripper_loss = (
|
||||
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
|
||||
)
|
||||
pos_loss = (
|
||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
|
||||
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
|
||||
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||
) * self.XYZ_SCALE
|
||||
rot_loss = (
|
||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
|
||||
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
|
||||
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||
) * self.ROT_SCALE
|
||||
|
||||
return {
|
||||
@@ -261,6 +267,32 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
|
||||
return action
|
||||
|
||||
|
||||
@register_action("franka_joint7")
|
||||
class FrankaJoint7ActionSpace(BaseActionSpace):
|
||||
"""Franka Panda joint-space: 7 joints, no gripper."""
|
||||
|
||||
dim_action = 7
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape, "pred/target shapes must match"
|
||||
B, T, D = pred.shape
|
||||
joints_loss = self.mse(pred, target) * self.JOINTS_SCALE
|
||||
return {"joints_loss": joints_loss}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""No preprocessing needed for 7 joint actions."""
|
||||
return proprio, action
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""Return directly (no sigmoid since no gripper)."""
|
||||
return action
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exports
|
||||
# =============================================================================
|
||||
@@ -271,5 +303,6 @@ __all__ = [
|
||||
"EE6DActionSpace",
|
||||
"JointActionSpace",
|
||||
"AGIBOTEE6DActionSpace",
|
||||
"FrankaJoint7ActionSpace",
|
||||
"ACTION_REGISTRY",
|
||||
]
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,16 +11,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
|
||||
""" Florence-2 configuration"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Florence2VisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
|
||||
@@ -118,7 +117,6 @@ class Florence2VisionConfig(PretrainedConfig):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
|
||||
class Florence2LanguageConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
|
||||
@@ -269,6 +267,7 @@ class Florence2LanguageConfig(PretrainedConfig):
|
||||
"The config can simply be saved and uploaded again to be fixed."
|
||||
)
|
||||
|
||||
|
||||
class Florence2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
|
||||
@@ -335,5 +334,4 @@ class Florence2Config(PretrainedConfig):
|
||||
if text_config is not None:
|
||||
self.text_config = Florence2LanguageConfig(**text_config)
|
||||
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -124,43 +124,37 @@ class XVLAConfig(PreTrainedConfig):
|
||||
# TODO: jadechoghari: provide default way, and do not hardcode
|
||||
# Ensure vision_config and text_config are provided with defaults if not specified
|
||||
config_dict = dict(self.florence_config)
|
||||
if 'vision_config' not in config_dict or config_dict['vision_config'] is None:
|
||||
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
|
||||
# Provide default vision config
|
||||
config_dict['vision_config'] = {
|
||||
'model_type': 'davit',
|
||||
'drop_path_rate': 0.1,
|
||||
'patch_size': [14, 7, 7, 7],
|
||||
'patch_stride': [4, 2, 2, 2],
|
||||
'patch_padding': [3, 1, 1, 1],
|
||||
'patch_prenorm': [False, True, True, True],
|
||||
'dim_embed': [256, 512, 1024, 2048],
|
||||
'num_heads': [8, 16, 32, 64],
|
||||
'num_groups': [8, 16, 32, 64],
|
||||
'depths': [1, 1, 9, 1],
|
||||
'window_size': 12,
|
||||
'projection_dim': 1024,
|
||||
'visual_temporal_embedding': {
|
||||
'type': 'COSINE',
|
||||
'max_temporal_embeddings': 100
|
||||
},
|
||||
'image_pos_embed': {
|
||||
'type': 'learned_abs_2d',
|
||||
'max_pos_embeddings': 50
|
||||
},
|
||||
'image_feature_source': ['spatial_avg_pool', 'temporal_avg_pool']
|
||||
config_dict["vision_config"] = {
|
||||
"model_type": "davit",
|
||||
"drop_path_rate": 0.1,
|
||||
"patch_size": [14, 7, 7, 7],
|
||||
"patch_stride": [4, 2, 2, 2],
|
||||
"patch_padding": [3, 1, 1, 1],
|
||||
"patch_prenorm": [False, True, True, True],
|
||||
"dim_embed": [256, 512, 1024, 2048],
|
||||
"num_heads": [8, 16, 32, 64],
|
||||
"num_groups": [8, 16, 32, 64],
|
||||
"depths": [1, 1, 9, 1],
|
||||
"window_size": 12,
|
||||
"projection_dim": 1024,
|
||||
"visual_temporal_embedding": {"type": "COSINE", "max_temporal_embeddings": 100},
|
||||
"image_pos_embed": {"type": "learned_abs_2d", "max_pos_embeddings": 50},
|
||||
"image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"],
|
||||
}
|
||||
if 'text_config' not in config_dict or config_dict['text_config'] is None:
|
||||
if "text_config" not in config_dict or config_dict["text_config"] is None:
|
||||
# Provide default text config
|
||||
config_dict['text_config'] = {
|
||||
'model_type': 'florence2_language',
|
||||
'vocab_size': 51289,
|
||||
'd_model': 1024,
|
||||
'encoder_layers': 12,
|
||||
'decoder_layers': 12,
|
||||
'encoder_attention_heads': 16,
|
||||
'decoder_attention_heads': 16,
|
||||
'encoder_ffn_dim': 4096,
|
||||
'decoder_ffn_dim': 4096,
|
||||
config_dict["text_config"] = {
|
||||
"model_type": "florence2_language",
|
||||
"vocab_size": 51289,
|
||||
"d_model": 1024,
|
||||
"encoder_layers": 12,
|
||||
"decoder_layers": 12,
|
||||
"encoder_attention_heads": 16,
|
||||
"decoder_attention_heads": 16,
|
||||
"encoder_ffn_dim": 4096,
|
||||
"decoder_ffn_dim": 4096,
|
||||
}
|
||||
self._florence_config_obj = Florence2Config(**config_dict)
|
||||
return self._florence_config_obj
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
@@ -87,7 +86,7 @@ class XVLAModel(nn.Module):
|
||||
input_ids: torch.LongTensor,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Encode text and multi-view images via Florence2 encoder.
|
||||
"""
|
||||
@@ -129,13 +128,14 @@ class XVLAModel(nn.Module):
|
||||
domain_id: torch.LongTensor,
|
||||
proprio: torch.Tensor,
|
||||
action: torch.Tensor,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
) -> dict[str, torch.Tensor]:
|
||||
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
t = (torch.rand(1, device=input_ids.device) + torch.arange(batch_size, device=input_ids.device) / batch_size) % (
|
||||
1 - 1e-5
|
||||
)
|
||||
t = (
|
||||
torch.rand(1, device=input_ids.device)
|
||||
+ torch.arange(batch_size, device=input_ids.device) / batch_size
|
||||
) % (1 - 1e-5)
|
||||
|
||||
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
|
||||
@@ -350,7 +350,9 @@ def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float
|
||||
ratio = max(current_width / width, current_height / height)
|
||||
resized_height = int(current_height / ratio)
|
||||
resized_width = int(current_width / ratio)
|
||||
resized_img = F.interpolate(img, size=(resized_height, resized_width), mode="bilinear", align_corners=False)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, height - resized_height)
|
||||
pad_width = max(0, width - resized_width)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import ProcessorMixin
|
||||
@@ -88,7 +88,7 @@ class XVLAProcessor(ProcessorMixin):
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
# ================== LANGUAGE ENCODING ==================
|
||||
def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
|
||||
def encode_language(self, language_instruction: str | list[str]) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Tokenize one or more language instructions.
|
||||
|
||||
@@ -117,11 +117,7 @@ class XVLAProcessor(ProcessorMixin):
|
||||
return {"input_ids": inputs["input_ids"]}
|
||||
|
||||
# ================== IMAGE ENCODING ==================
|
||||
def encode_image(
|
||||
self,
|
||||
images: Union[List, List[List]],
|
||||
**kwargs
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
def encode_image(self, images: list | list[list], **kwargs) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Preprocess one or more sets of multi-view images.
|
||||
|
||||
@@ -157,8 +153,7 @@ class XVLAProcessor(ProcessorMixin):
|
||||
# Pad to self.num_views
|
||||
if V_exist < self.num_views:
|
||||
processed = torch.cat(
|
||||
[processed,
|
||||
processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
|
||||
[processed, processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
@@ -177,10 +172,10 @@ class XVLAProcessor(ProcessorMixin):
|
||||
# ================== COMBINED CALL ==================
|
||||
def __call__(
|
||||
self,
|
||||
images: Optional[Union[List, List[List]]] = None,
|
||||
language_instruction: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
images: list | list[list] | None = None,
|
||||
language_instruction: str | list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Combine image and text encoding into a unified multimodal input.
|
||||
|
||||
@@ -202,7 +197,7 @@ class XVLAProcessor(ProcessorMixin):
|
||||
"image_mask": [B, num_views], optional
|
||||
}
|
||||
"""
|
||||
outputs: Dict[str, Any] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
# Encode language if provided
|
||||
if language_instruction is not None:
|
||||
@@ -243,7 +238,9 @@ def make_xvla_pre_post_processors(
|
||||
padding_side=config.tokenizer_padding_side,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(features=features, norm_map=config.normalization_mapping, stats=dataset_stats),
|
||||
NormalizerProcessorStep(
|
||||
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
|
||||
@@ -17,17 +17,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from typing import Final, Iterable, Tuple
|
||||
from typing import Final
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ------------------------------- Small utils ----------------------------------
|
||||
|
||||
def _to_2tuple(x) -> Tuple:
|
||||
|
||||
def _to_2tuple(x) -> tuple:
|
||||
"""Minimal replacement for timm.layers.to_2tuple."""
|
||||
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
|
||||
t = tuple(x)
|
||||
@@ -42,6 +43,7 @@ def _has_sdp_attention() -> bool:
|
||||
|
||||
# ---------------------------------- MLP --------------------------------------
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""
|
||||
MLP used in ViT-style blocks.
|
||||
@@ -55,8 +57,8 @@ class Mlp(nn.Module):
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
norm_layer: type[nn.Module] | None = None,
|
||||
bias: bool | Tuple[bool, bool] = True,
|
||||
drop: float | Tuple[float, float] = 0.0,
|
||||
bias: bool | tuple[bool, bool] = True,
|
||||
drop: float | tuple[float, float] = 0.0,
|
||||
use_conv: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -86,6 +88,7 @@ class Mlp(nn.Module):
|
||||
|
||||
# -------------------------------- Attention ----------------------------------
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Multi-Head Self-Attention with optional fused SDPA fallback.
|
||||
@@ -110,7 +113,7 @@ class Attention(nn.Module):
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.fused_attn = _has_sdp_attention()
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
@@ -143,17 +146,19 @@ class Attention(nn.Module):
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
) # [B, H, T, Dh]
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
|
||||
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v # [B, H, T, Dh]
|
||||
x = attn @ v # [B, H, T, Dh]
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
|
||||
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
@@ -161,6 +166,7 @@ class Attention(nn.Module):
|
||||
|
||||
# ------------------------------- Utilities -----------------------------------
|
||||
|
||||
|
||||
def basic_init(module: nn.Module) -> None:
|
||||
"""
|
||||
Apply a basic initialization scheme to Linear layers.
|
||||
@@ -194,9 +200,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
|
||||
/ half
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
|
||||
)
|
||||
args = t[:, None] * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
@@ -207,6 +211,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc
|
||||
|
||||
# ------------------------------- Core Layers ----------------------------------
|
||||
|
||||
|
||||
class DomainAwareLinear(nn.Module):
|
||||
"""
|
||||
Linear layer with domain-conditioned parameters (per-sample).
|
||||
@@ -283,6 +288,7 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
# --------------------------- Main Model ---------------------------------------
|
||||
|
||||
|
||||
class SoftPromptedTransformer(nn.Module):
|
||||
"""
|
||||
Multi-modal, domain-aware Transformer with optional soft prompts.
|
||||
@@ -318,7 +324,9 @@ class SoftPromptedTransformer(nn.Module):
|
||||
|
||||
if use_hetero_proj:
|
||||
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
|
||||
self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
|
||||
self.aux_visual_proj = DomainAwareLinear(
|
||||
multi_modal_input_size, hidden_size, num_domains=num_domains
|
||||
)
|
||||
else:
|
||||
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||
@@ -367,16 +375,20 @@ class SoftPromptedTransformer(nn.Module):
|
||||
B, num_actions = action_with_noise.shape[:2]
|
||||
|
||||
# Encode (action + proprio + time) → tokens
|
||||
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
|
||||
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
|
||||
time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
|
||||
proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
|
||||
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
|
||||
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
|
||||
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
|
||||
|
||||
# Project visual streams and concatenate
|
||||
if self.use_hetero_proj:
|
||||
x = torch.cat(
|
||||
[x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
|
||||
[
|
||||
x,
|
||||
self.vlm_proj(vlm_features, domain_id),
|
||||
self.aux_visual_proj(aux_visual_inputs, domain_id),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
@@ -385,9 +397,7 @@ class SoftPromptedTransformer(nn.Module):
|
||||
# Add positional embeddings (truncate if needed)
|
||||
seq_len = x.shape[1]
|
||||
if seq_len > self.pos_emb.shape[1]:
|
||||
raise ValueError(
|
||||
f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
|
||||
)
|
||||
raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
|
||||
x = x + self.pos_emb[:, :seq_len, :]
|
||||
|
||||
# Append soft prompts
|
||||
|
||||
+1
-2
@@ -1,6 +1,5 @@
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
|
||||
cfg = make_policy_config("xvla")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user