mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
first commit
This commit is contained in:
@@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
@@ -31,4 +32,5 @@ __all__ = [
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"GrootConfig",
|
||||
"XVLAConfig",
|
||||
]
|
||||
|
||||
@@ -39,6 +39,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import validate_visual_features_consistency
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
@@ -107,6 +108,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
return GrootPolicy
|
||||
elif name == "xvla":
|
||||
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||
|
||||
return XVLAPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -150,6 +155,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
elif policy_type == "xvla":
|
||||
return XVLAConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
@@ -329,6 +336,13 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, XVLAConfig):
|
||||
from lerobot.policies.xvla.processing_xvla import make_xvla_pre_post_processors
|
||||
|
||||
processors = make_xvla_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Iterable, Tuple, Dict, Type
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# =============================================================================
|
||||
# Registry
|
||||
# =============================================================================
|
||||
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}
|
||||
|
||||
|
||||
def register_action(name: str):
|
||||
"""Decorator for registering a new action space."""
|
||||
def _wrap(cls):
|
||||
key = name.lower()
|
||||
if key in ACTION_REGISTRY:
|
||||
raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
|
||||
ACTION_REGISTRY[key] = cls
|
||||
cls.name = key
|
||||
return cls
|
||||
return _wrap
|
||||
|
||||
|
||||
def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
|
||||
"""Instantiate a registered action space by name."""
|
||||
key = name.lower()
|
||||
if key not in ACTION_REGISTRY:
|
||||
raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
|
||||
return ACTION_REGISTRY[key](**kwargs)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base class
|
||||
# =============================================================================
|
||||
class BaseActionSpace(nn.Module):
|
||||
"""
|
||||
Abstract base class for all action-space definitions.
|
||||
|
||||
Each subclass defines:
|
||||
- `dim_action`: dimension of the action vector.
|
||||
- `gripper_idx`: indices of gripper channels.
|
||||
- `compute_loss(pred, target)`: supervised loss for this space.
|
||||
- `preprocess(proprio, action, mode)`: pre-step modifications.
|
||||
- `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
dim_action: int = 0
|
||||
gripper_idx: Tuple[int, ...] = ()
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Core supervised loss
|
||||
# ---------------------------------------------------------------------
|
||||
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
"""Alias for compute_loss."""
|
||||
return self.compute_loss(pred, target)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Space-level hooks
|
||||
# ---------------------------------------------------------------------
|
||||
def preprocess(
|
||||
self,
|
||||
proprio: torch.Tensor,
|
||||
action: torch.Tensor,
|
||||
mode: str = "train",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Default: return unchanged."""
|
||||
return proprio, action
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""Default: return unchanged."""
|
||||
return action
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Utilities
|
||||
# =============================================================================
|
||||
def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None:
|
||||
bad = [i for i in idx if i < 0 or i >= D]
|
||||
if bad:
|
||||
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Implementations
|
||||
# =============================================================================
|
||||
@register_action("ee6d")
|
||||
class EE6DActionSpace(BaseActionSpace):
|
||||
"""End-effector layout with xyz, 6D rotation, and gripper channels."""
|
||||
|
||||
dim_action = 20
|
||||
gripper_idx = (9, 19)
|
||||
GRIPPER_SCALE = 1.0
|
||||
XYZ_SCALE = 500.0
|
||||
ROT_SCALE = 10.0
|
||||
|
||||
POS_IDX_1 = (0, 1, 2)
|
||||
POS_IDX_2 = (10, 11, 12)
|
||||
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
|
||||
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape, "pred/target shapes must match"
|
||||
B, T, D = pred.shape
|
||||
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
|
||||
|
||||
# Gripper BCE
|
||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
||||
|
||||
# XYZ position
|
||||
pos_loss = (
|
||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
|
||||
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||
) * self.XYZ_SCALE
|
||||
|
||||
# Rotation 6D
|
||||
rot_loss = (
|
||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
|
||||
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||
) * self.ROT_SCALE
|
||||
|
||||
return {
|
||||
"position_loss": pos_loss,
|
||||
"rotate6D_loss": rot_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""Zero-out gripper channels in proprio/action."""
|
||||
proprio_m = proprio.clone()
|
||||
action_m = action.clone()
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply sigmoid to gripper logits."""
|
||||
if action.size(-1) > max(self.gripper_idx):
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
return action
|
||||
|
||||
|
||||
@register_action("joint")
|
||||
class JointActionSpace(BaseActionSpace):
|
||||
"""Joint-space layout with joints + gripper only."""
|
||||
|
||||
dim_action = 14
|
||||
gripper_idx = (6, 13)
|
||||
GRIPPER_SCALE = 0.1
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape
|
||||
B, T, D = pred.shape
|
||||
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
|
||||
|
||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
||||
|
||||
joints_idx = tuple(i for i in range(D) if i not in set(self.gripper_idx))
|
||||
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""Zero-out gripper channels in proprio/action."""
|
||||
proprio_m = proprio.clone()
|
||||
action_m = action.clone()
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply sigmoid to gripper logits."""
|
||||
if action.size(-1) > max(self.gripper_idx):
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
return action
|
||||
|
||||
|
||||
@register_action("agibot_ee6d")
|
||||
class AGIBOTEE6DActionSpace(BaseActionSpace):
|
||||
"""AGI-bot variant of EE6DActionSpace using MSE for all components."""
|
||||
|
||||
dim_action = 20
|
||||
gripper_idx = (9, 19)
|
||||
GRIPPER_SCALE = 10.0
|
||||
XYZ_SCALE = 500.0
|
||||
ROT_SCALE = 10.0
|
||||
POS_IDX_1 = (0, 1, 2)
|
||||
POS_IDX_2 = (10, 11, 12)
|
||||
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
|
||||
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape
|
||||
B, T, D = pred.shape
|
||||
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
|
||||
|
||||
gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
|
||||
pos_loss = (
|
||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
|
||||
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||
) * self.XYZ_SCALE
|
||||
rot_loss = (
|
||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
|
||||
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||
) * self.ROT_SCALE
|
||||
|
||||
return {
|
||||
"position_loss": pos_loss,
|
||||
"rotate6D_loss": rot_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""No preprocessing applied in AGIBOT variant."""
|
||||
return proprio, action
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""AGIBOT does not postprocess."""
|
||||
return action
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exports
|
||||
# =============================================================================
|
||||
__all__ = [
|
||||
"BaseActionSpace",
|
||||
"build_action_space",
|
||||
"register_action",
|
||||
"EE6DActionSpace",
|
||||
"JointActionSpace",
|
||||
"AGIBOTEE6DActionSpace",
|
||||
"ACTION_REGISTRY",
|
||||
]
|
||||
@@ -0,0 +1,339 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
""" Florence-2 configuration"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
class Florence2VisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
|
||||
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
||||
The dropout rate of the drop path layer.
|
||||
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
|
||||
The patch size of the image.
|
||||
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
|
||||
The patch stride of the image.
|
||||
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
|
||||
The patch padding of the image.
|
||||
patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
|
||||
Whether to apply layer normalization before the patch embedding layer.
|
||||
enable_checkpoint (`bool`, *optional*, defaults to False):
|
||||
Whether to enable checkpointing.
|
||||
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
|
||||
The dimension of the embedding layer.
|
||||
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
||||
The number of attention heads.
|
||||
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
||||
The number of groups.
|
||||
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
|
||||
The depth of the model.
|
||||
window_size (`int`, *optional*, defaults to 12):
|
||||
The window size of the model.
|
||||
projection_dim (`int`, *optional*, defaults to 1024):
|
||||
The dimension of the projection layer.
|
||||
visual_temporal_embedding (`dict`, *optional*):
|
||||
The configuration of the visual temporal embedding.
|
||||
image_pos_embed (`dict`, *optional*):
|
||||
The configuration of the image position embedding.
|
||||
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
|
||||
The source of the image feature.
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
|
||||
|
||||
>>> # Initializing a Florence2 Vision style configuration
|
||||
>>> configuration = Florence2VisionConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights)
|
||||
>>> model = Florence2VisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "davit"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
drop_path_rate=0.1,
|
||||
patch_size=[7, 3, 3, 3],
|
||||
patch_stride=[4, 2, 2, 2],
|
||||
patch_padding=[3, 1, 1, 1],
|
||||
patch_prenorm=[False, True, True, True],
|
||||
enable_checkpoint=False,
|
||||
dim_embed=[256, 512, 1024, 2048],
|
||||
num_heads=[8, 16, 32, 64],
|
||||
num_groups=[8, 16, 32, 64],
|
||||
depths=[1, 1, 9, 1],
|
||||
window_size=12,
|
||||
projection_dim=1024,
|
||||
visual_temporal_embedding=None,
|
||||
image_pos_embed=None,
|
||||
image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
|
||||
**kwargs,
|
||||
):
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
self.patch_padding = patch_padding
|
||||
self.patch_prenorm = patch_prenorm
|
||||
self.enable_checkpoint = enable_checkpoint
|
||||
self.dim_embed = dim_embed
|
||||
self.num_heads = num_heads
|
||||
self.num_groups = num_groups
|
||||
self.depths = depths
|
||||
self.window_size = window_size
|
||||
self.projection_dim = projection_dim
|
||||
self.visual_temporal_embedding = visual_temporal_embedding
|
||||
self.image_pos_embed = image_pos_embed
|
||||
self.image_feature_source = image_feature_source
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
|
||||
class Florence2LanguageConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the BART
|
||||
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 51289):
|
||||
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Florence2LanguageModel`].
|
||||
d_model (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
encoder_layers (`int`, *optional*, defaults to 12):
|
||||
Number of encoder layers.
|
||||
decoder_layers (`int`, *optional*, defaults to 12):
|
||||
Number of decoder layers.
|
||||
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for classifier.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
init_std (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
scale_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
num_labels (`int`, *optional*, defaults to 3):
|
||||
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
|
||||
forced_eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
|
||||
`eos_token_id`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
|
||||
|
||||
>>> # Initializing a Florence2 Language style configuration
|
||||
>>> configuration = Florence2LanguageConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights)
|
||||
>>> model = Florence2LangaugeModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "florence2_language"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=51289,
|
||||
max_position_embeddings=1024,
|
||||
encoder_layers=12,
|
||||
encoder_ffn_dim=4096,
|
||||
encoder_attention_heads=16,
|
||||
decoder_layers=12,
|
||||
decoder_ffn_dim=4096,
|
||||
decoder_attention_heads=16,
|
||||
encoder_layerdrop=0.0,
|
||||
decoder_layerdrop=0.0,
|
||||
activation_function="gelu",
|
||||
d_model=1024,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
init_std=0.02,
|
||||
classifier_dropout=0.0,
|
||||
scale_embedding=False,
|
||||
use_cache=True,
|
||||
num_labels=3,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
is_encoder_decoder=True,
|
||||
decoder_start_token_id=2,
|
||||
forced_eos_token_id=2,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.d_model = d_model
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.activation_function = activation_function
|
||||
self.init_std = init_std
|
||||
self.encoder_layerdrop = encoder_layerdrop
|
||||
self.decoder_layerdrop = decoder_layerdrop
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.use_cache = use_cache
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
super().__init__(
|
||||
num_labels=num_labels,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# ensure backward compatibility for BART CNN models
|
||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
warnings.warn(
|
||||
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
|
||||
"The config can simply be saved and uploaded again to be fixed."
|
||||
)
|
||||
|
||||
class Florence2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
|
||||
Florence-2 model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`Florence2VisionConfig`, *optional*):
|
||||
Custom vision config or dict
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
The config object of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
vocab_size (`int`, *optional*, defaults to 51289):
|
||||
Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
|
||||
projection_dim (`int`, *optional*, defaults to 1024):
|
||||
Dimension of the multimodal projection space.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
|
||||
|
||||
>>> # Initializing a clip-like vision config
|
||||
>>> vision_config = CLIPVisionConfig()
|
||||
|
||||
>>> # Initializing a Bart config
|
||||
>>> text_config = BartConfig()
|
||||
|
||||
>>> # Initializing a Florence-2 configuration
|
||||
>>> configuration = Florence2Config(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the florence-2 configuration
|
||||
>>> model = Florence2ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "florence2"
|
||||
is_composition = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
vocab_size=51289,
|
||||
projection_dim=1024,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.vocab_size = vocab_size
|
||||
self.projection_dim = projection_dim
|
||||
if vision_config is not None:
|
||||
vision_config = Florence2VisionConfig(**vision_config)
|
||||
self.vision_config = vision_config
|
||||
self.vocab_size = self.vocab_size
|
||||
|
||||
self.text_config = text_config
|
||||
if text_config is not None:
|
||||
self.text_config = Florence2LanguageConfig(**text_config)
|
||||
|
||||
|
||||
super().__init__(**kwargs)
|
||||
@@ -0,0 +1,217 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
from .configuration_florence2 import Florence2Config
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("xvla")
|
||||
@dataclass
|
||||
class XVLAConfig(PreTrainedConfig):
|
||||
"""
|
||||
Configuration class for the XVLA (Extended Vision-Language-Action) policy so it can
|
||||
plug into the LeRobot training stack.
|
||||
|
||||
The config mirrors the knobs exposed in the original XVLA repository but also
|
||||
declares the input/output feature contract required by LeRobot.
|
||||
"""
|
||||
|
||||
# Input / output structure
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 32
|
||||
n_action_steps: int = 32
|
||||
num_actions: int = 32
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Florence2 backbone and tokenizer configuration
|
||||
florence_config: dict[str, Any] | Florence2Config = field(default_factory=dict)
|
||||
tokenizer_name: str = "facebook/bart-large"
|
||||
tokenizer_max_length: int = 64
|
||||
tokenizer_padding_side: str = "right"
|
||||
pad_language_to: str = "max_length"
|
||||
|
||||
# Transformer head
|
||||
hidden_size: int = 1024
|
||||
depth: int = 24
|
||||
num_heads: int = 16
|
||||
mlp_ratio: float = 4.0
|
||||
num_domains: int = 30
|
||||
len_soft_prompts: int = 32
|
||||
dim_time: int = 32
|
||||
max_len_seq: int = 512
|
||||
use_hetero_proj: bool = False
|
||||
|
||||
# Action & proprioception
|
||||
action_mode: str = "ee6d"
|
||||
num_denoising_steps: int = 10
|
||||
use_proprio: bool = True
|
||||
max_state_dim: int = 32
|
||||
domain_feature_key: str | None = None
|
||||
|
||||
# Vision preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] | None = (518, 518)
|
||||
num_image_views: int | None = None
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
if self.num_actions <= 0:
|
||||
raise ValueError("`num_actions` must be strictly positive.")
|
||||
if self.chunk_size != self.num_actions:
|
||||
self.chunk_size = self.num_actions
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
|
||||
)
|
||||
if isinstance(self.florence_config, Florence2Config):
|
||||
self.florence_config = self.florence_config.to_dict()
|
||||
if self.num_image_views is not None and self.num_image_views <= 0:
|
||||
raise ValueError("`num_image_views` must be > 0 when specified.")
|
||||
self._florence_config_obj: Florence2Config | None = None
|
||||
|
||||
def get_florence_config(self) -> Florence2Config:
|
||||
"""
|
||||
Build (and cache) the Florence2 transformer config that should back the VLM.
|
||||
"""
|
||||
if self._florence_config_obj is None:
|
||||
if isinstance(self.florence_config, Florence2Config):
|
||||
self._florence_config_obj = self.florence_config
|
||||
else:
|
||||
# TODO: jadechoghari: provide default way, and do not hardcode
|
||||
# Ensure vision_config and text_config are provided with defaults if not specified
|
||||
config_dict = dict(self.florence_config)
|
||||
if 'vision_config' not in config_dict or config_dict['vision_config'] is None:
|
||||
# Provide default vision config
|
||||
config_dict['vision_config'] = {
|
||||
'model_type': 'davit',
|
||||
'drop_path_rate': 0.1,
|
||||
'patch_size': [7, 3, 3, 3],
|
||||
'patch_stride': [4, 2, 2, 2],
|
||||
'patch_padding': [3, 1, 1, 1],
|
||||
'patch_prenorm': [False, True, True, True],
|
||||
'dim_embed': [256, 512, 1024, 2048],
|
||||
'num_heads': [8, 16, 32, 64],
|
||||
'num_groups': [8, 16, 32, 64],
|
||||
'depths': [1, 1, 9, 1],
|
||||
'window_size': 12,
|
||||
'projection_dim': 1024,
|
||||
'visual_temporal_embedding': {
|
||||
'type': 'COSINE',
|
||||
'max_temporal_embeddings': 100
|
||||
},
|
||||
'image_pos_embed': {
|
||||
'type': 'learned_abs_2d',
|
||||
'max_pos_embeddings': 50
|
||||
},
|
||||
'image_feature_source': ['spatial_avg_pool', 'temporal_avg_pool']
|
||||
}
|
||||
if 'text_config' not in config_dict or config_dict['text_config'] is None:
|
||||
# Provide default text config
|
||||
config_dict['text_config'] = {
|
||||
'model_type': 'florence2_language',
|
||||
'vocab_size': 51289,
|
||||
'd_model': 1024,
|
||||
'encoder_layers': 12,
|
||||
'decoder_layers': 12,
|
||||
'encoder_attention_heads': 16,
|
||||
'decoder_attention_heads': 16,
|
||||
'encoder_ffn_dim': 4096,
|
||||
'decoder_ffn_dim': 4096,
|
||||
}
|
||||
self._florence_config_obj = Florence2Config(**config_dict)
|
||||
return self._florence_config_obj
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("XVLA requires at least one visual feature in the inputs.")
|
||||
if self.use_proprio and self.robot_state_feature is None:
|
||||
raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")
|
||||
if self.num_image_views is None:
|
||||
self.num_image_views = len(self.image_features) + self.empty_cameras
|
||||
else:
|
||||
self.num_image_views = max(self.num_image_views, len(self.image_features) + self.empty_cameras)
|
||||
|
||||
if self.empty_cameras > 0:
|
||||
height, width = (480, 640)
|
||||
if self.resize_imgs_with_padding is not None:
|
||||
height, width = self.resize_imgs_with_padding
|
||||
for idx in range(self.empty_cameras):
|
||||
key = f"{OBS_IMAGES}.empty_camera_{idx}"
|
||||
if key not in self.input_features:
|
||||
self.input_features[key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, height, width),
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,388 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
|
||||
from .action_hub import build_action_space
|
||||
from .configuration_florence2 import Florence2Config
|
||||
from .configuration_xvla import XVLAConfig
|
||||
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||
from .transformer import SoftPromptedTransformer
|
||||
|
||||
|
||||
class XVLAModel(nn.Module):
|
||||
"""
|
||||
XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: XVLAConfig,
|
||||
florence_config: Florence2Config,
|
||||
proprio_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_actions: int = config.num_actions
|
||||
self.use_proprio: bool = config.use_proprio
|
||||
self.action_space = build_action_space(config.action_mode.lower())
|
||||
self.dim_action = self.action_space.dim_action
|
||||
self.dim_proprio = proprio_dim
|
||||
|
||||
self.vlm = Florence2ForConditionalGeneration(florence_config)
|
||||
if hasattr(self.vlm, "language_model"):
|
||||
lm = self.vlm.language_model
|
||||
if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
|
||||
del lm.model.decoder
|
||||
if hasattr(lm, "lm_head"):
|
||||
del lm.lm_head
|
||||
|
||||
projection_dim = getattr(self.vlm.config, "projection_dim", None)
|
||||
if projection_dim is None:
|
||||
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
|
||||
|
||||
self.transformer = SoftPromptedTransformer(
|
||||
hidden_size=config.hidden_size,
|
||||
multi_modal_input_size=projection_dim,
|
||||
depth=config.depth,
|
||||
num_heads=config.num_heads,
|
||||
mlp_ratio=config.mlp_ratio,
|
||||
num_domains=config.num_domains,
|
||||
dim_action=self.dim_action,
|
||||
dim_propio=self.dim_proprio,
|
||||
len_soft_prompts=config.len_soft_prompts,
|
||||
dim_time=config.dim_time,
|
||||
max_len_seq=config.max_len_seq,
|
||||
use_hetero_proj=config.use_hetero_proj,
|
||||
)
|
||||
|
||||
def forward_vlm(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Encode text and multi-view images via Florence2 encoder.
|
||||
"""
|
||||
batch_size, num_views = pixel_values.shape[:2]
|
||||
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
|
||||
flat_images = pixel_values.flatten(0, 1)
|
||||
|
||||
num_valid = int(flat_mask.sum().item())
|
||||
if num_valid == 0:
|
||||
raise ValueError("At least one image view must be valid per batch.")
|
||||
|
||||
valid_images = flat_images[flat_mask]
|
||||
valid_feats = self.vlm._encode_image(valid_images)
|
||||
tokens_per_view, hidden_dim = valid_feats.shape[1:]
|
||||
|
||||
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
|
||||
image_features[flat_mask] = valid_feats
|
||||
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
|
||||
|
||||
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
|
||||
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
|
||||
image_features[:, 0],
|
||||
inputs_embeds,
|
||||
)
|
||||
|
||||
enc_out = self.vlm.language_model.model.encoder(
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=merged_embeds,
|
||||
)[0]
|
||||
|
||||
aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
|
||||
return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
image_input: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
domain_id: torch.LongTensor,
|
||||
proprio: torch.Tensor,
|
||||
action: torch.Tensor,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
t = (torch.rand(1, device=input_ids.device) + torch.arange(batch_size, device=input_ids.device) / batch_size) % (
|
||||
1 - 1e-5
|
||||
)
|
||||
|
||||
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
|
||||
|
||||
pred_action = self.transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=action_noisy_m,
|
||||
t=t,
|
||||
proprio=proprio_m,
|
||||
**enc,
|
||||
)
|
||||
return self.action_space.compute_loss(pred_action, action)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_actions(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
image_input: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
domain_id: torch.LongTensor,
|
||||
proprio: torch.Tensor,
|
||||
steps: int,
|
||||
) -> torch.Tensor:
|
||||
self.eval()
|
||||
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
action_dim = self.dim_action
|
||||
|
||||
x1 = torch.randn(batch_size, self.num_actions, action_dim, device=proprio.device, dtype=proprio.dtype)
|
||||
action = torch.zeros_like(x1)
|
||||
|
||||
steps = max(1, int(steps))
|
||||
for i in range(steps, 0, -1):
|
||||
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=proprio.dtype)
|
||||
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
|
||||
action = self.transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=x_t_m,
|
||||
proprio=proprio_m,
|
||||
t=t,
|
||||
**enc,
|
||||
)
|
||||
return self.action_space.postprocess(action)
|
||||
|
||||
|
||||
class XVLAPolicy(PreTrainedPolicy):
|
||||
"""LeRobot-compliant wrapper built around the XVLA model."""
|
||||
|
||||
config_class = XVLAConfig
|
||||
name = "xvla"
|
||||
|
||||
def __init__(self, config: XVLAConfig):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
|
||||
florence_config = config.get_florence_config()
|
||||
proprio_dim = config.max_state_dim if config.use_proprio else 0
|
||||
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
||||
if not self.config.use_proprio or OBS_STATE not in batch:
|
||||
return torch.zeros(batch_size, 0, device=device)
|
||||
state = batch[OBS_STATE]
|
||||
if state.ndim > 2:
|
||||
state = state[:, -1, :]
|
||||
return pad_vector(state, self.model.dim_proprio)
|
||||
|
||||
def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
"All image features are missing from the batch. "
|
||||
f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
|
||||
)
|
||||
|
||||
images = []
|
||||
masks = []
|
||||
for key in present_img_keys:
|
||||
img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
|
||||
images.append(img)
|
||||
masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))
|
||||
|
||||
stacked_imgs = torch.stack(images, dim=1)
|
||||
stacked_masks = torch.stack(masks, dim=1)
|
||||
|
||||
total_views = self.config.num_image_views or stacked_imgs.size(1)
|
||||
total_views = max(total_views, stacked_imgs.size(1))
|
||||
num_pad = total_views - stacked_imgs.size(1)
|
||||
if num_pad > 0:
|
||||
pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
|
||||
pad_imgs = stacked_imgs.new_zeros(pad_shape)
|
||||
pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
|
||||
stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
|
||||
stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)
|
||||
|
||||
return stacked_imgs, stacked_masks
|
||||
|
||||
def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
||||
candidate = None
|
||||
if self.config.domain_feature_key and self.config.domain_feature_key in batch:
|
||||
candidate = batch[self.config.domain_feature_key]
|
||||
elif "domain_id" in batch:
|
||||
candidate = batch["domain_id"]
|
||||
|
||||
if candidate is None:
|
||||
return torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
if not isinstance(candidate, torch.Tensor):
|
||||
candidate = torch.as_tensor(candidate, device=device)
|
||||
else:
|
||||
candidate = candidate.to(device=device)
|
||||
|
||||
if candidate.ndim == 0:
|
||||
candidate = candidate.expand(batch_size)
|
||||
if candidate.ndim > 1:
|
||||
candidate = candidate.view(candidate.shape[0], -1)[:, 0]
|
||||
if candidate.shape[0] != batch_size:
|
||||
candidate = candidate.expand(batch_size)
|
||||
return candidate.to(dtype=torch.long)
|
||||
|
||||
def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
if ACTION not in batch:
|
||||
raise ValueError("Batch is missing action targets required for training.")
|
||||
actions = batch[ACTION]
|
||||
if actions.ndim == 2:
|
||||
actions = actions.unsqueeze(1)
|
||||
actions = pad_tensor_along_dim(actions, self.config.num_actions, dim=1)
|
||||
if actions.shape[-1] != self.model.dim_action:
|
||||
actions = pad_vector(actions, self.model.dim_action)
|
||||
return actions
|
||||
|
||||
def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
input_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
batch_size = input_ids.shape[0]
|
||||
images, image_mask = self._prepare_images(batch)
|
||||
domain_id = self._get_domain_id(batch, batch_size, images.device)
|
||||
proprio = self._prepare_state(batch, batch_size, images.device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"image_input": images,
|
||||
"image_mask": image_mask,
|
||||
"domain_id": domain_id,
|
||||
"proprio": proprio,
|
||||
}
|
||||
|
||||
def _trim_action_dim(self, actions: Tensor) -> Tensor:
|
||||
feature = self.config.action_feature
|
||||
if feature is None:
|
||||
return actions
|
||||
desired_dim = feature.shape[0]
|
||||
if desired_dim == actions.shape[-1]:
|
||||
return actions
|
||||
if desired_dim < actions.shape[-1]:
|
||||
return actions[..., :desired_dim]
|
||||
return pad_vector(actions, desired_dim)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
inputs = self._build_model_inputs(batch)
|
||||
targets = self._prepare_action_targets(batch)
|
||||
losses = self.model(action=targets, **inputs)
|
||||
total_loss = sum(losses.values())
|
||||
|
||||
log_dict = {k: v.detach().item() for k, v in losses.items()}
|
||||
log_dict["loss"] = total_loss.detach().item()
|
||||
return total_loss, log_dict
|
||||
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
inputs = self._build_model_inputs(batch)
|
||||
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
|
||||
actions = self._trim_action_dim(actions)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
return self._get_action_chunk(batch)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self._get_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
|
||||
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")
|
||||
|
||||
current_height, current_width = img.shape[2:]
|
||||
if current_height == height and current_width == width:
|
||||
return img
|
||||
|
||||
ratio = max(current_width / width, current_height / height)
|
||||
resized_height = int(current_height / ratio)
|
||||
resized_width = int(current_width / ratio)
|
||||
resized_img = F.interpolate(img, size=(resized_height, resized_width), mode="bilinear", align_corners=False)
|
||||
|
||||
pad_height = max(0, height - resized_height)
|
||||
pad_width = max(0, width - resized_width)
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
if new_dim == 0:
|
||||
shape = list(vector.shape)
|
||||
shape[-1] = 0
|
||||
return vector.new_zeros(*shape)
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = vector.new_zeros(*shape)
|
||||
length = min(current_dim, new_dim)
|
||||
new_vector[..., :length] = vector[..., :length]
|
||||
return new_vector
|
||||
|
||||
|
||||
def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
|
||||
current_len = tensor.size(dim)
|
||||
if current_len == target_len:
|
||||
return tensor
|
||||
if current_len > target_len:
|
||||
slices = [slice(None)] * tensor.dim()
|
||||
slices[dim] = slice(0, target_len)
|
||||
return tensor[tuple(slices)]
|
||||
pad_shape = list(tensor.shape)
|
||||
pad_shape[dim] = target_len - current_len
|
||||
pad_tensor = tensor.new_zeros(pad_shape)
|
||||
return torch.cat([tensor, pad_tensor], dim=dim)
|
||||
@@ -0,0 +1,268 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
class XVLAProcessor(ProcessorMixin):
|
||||
"""
|
||||
XVLAProcessor: Unified multimodal processor for XVLA models.
|
||||
|
||||
Handles:
|
||||
- Multi-view image inputs (e.g., from multiple cameras).
|
||||
- Batch processing for multiple samples.
|
||||
- Joint tokenization and image tensor preparation.
|
||||
|
||||
This processor combines an image processor and a tokenizer under a single interface
|
||||
so that users can call it directly like:
|
||||
|
||||
>>> processor = XVLAProcessor.from_pretrained("path/to/xvla")
|
||||
>>> inputs = processor(images=batch_images, language_instruction=batch_texts)
|
||||
|
||||
It is fully compatible with the Hugging Face AutoProcessor API.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
num_views : int, default=3
|
||||
Expected number of image views per sample. Missing views will be padded with zeros.
|
||||
language_max_length : int, default=50
|
||||
Maximum token length for text encoding.
|
||||
attributes : list
|
||||
Required by ProcessorMixin to know which submodules are stored and reloaded.
|
||||
image_processor_class : str
|
||||
The name of the associated image processor class.
|
||||
tokenizer_class : tuple(str)
|
||||
The names of compatible tokenizer classes.
|
||||
"""
|
||||
|
||||
num_views: int = 3
|
||||
language_max_length: int = 50
|
||||
|
||||
# Hugging Face ProcessorMixin-required metadata
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None):
|
||||
"""
|
||||
Initialize XVLAProcessor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image_processor : PreTrainedImageProcessor, optional
|
||||
The image processor used to normalize/resize images.
|
||||
tokenizer : PreTrainedTokenizer, optional
|
||||
The tokenizer used for text tokenization.
|
||||
"""
|
||||
# ProcessorMixin automatically saves these under self.image_processor / self.tokenizer
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
# ================== LANGUAGE ENCODING ==================
|
||||
def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Tokenize one or more language instructions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
language_instruction : str or List[str]
|
||||
A single instruction or a batch of instructions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, torch.Tensor]
|
||||
{
|
||||
"input_ids": tensor of shape [B, L]
|
||||
}
|
||||
"""
|
||||
if isinstance(language_instruction, str):
|
||||
language_instruction = [language_instruction]
|
||||
|
||||
inputs = self.tokenizer(
|
||||
language_instruction,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=self.language_max_length,
|
||||
truncation=True,
|
||||
)
|
||||
return {"input_ids": inputs["input_ids"]}
|
||||
|
||||
# ================== IMAGE ENCODING ==================
|
||||
def encode_image(
|
||||
self,
|
||||
images: Union[List, List[List]],
|
||||
**kwargs
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Preprocess one or more sets of multi-view images.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
images : List or List[List]
|
||||
Single sample: [img1, img2, ...]
|
||||
Batch: [[img1a, img1b], [img2a, img2b, img2c], ...]
|
||||
Each image may be a PIL.Image, NumPy array, or torch.Tensor.
|
||||
|
||||
kwargs : dict
|
||||
Extra arguments passed to the underlying image processor
|
||||
(e.g., `do_resize=False`, `size=(224,224)`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, torch.Tensor]
|
||||
{
|
||||
"image_input": tensor [B, num_views, C, H, W],
|
||||
"image_mask": tensor [B, num_views]
|
||||
}
|
||||
"""
|
||||
# Normalize to batch form
|
||||
if not isinstance(images[0], (list, tuple)):
|
||||
images = [images] # convert single sample to batch of size 1
|
||||
|
||||
batch_imgs, batch_masks = [], []
|
||||
|
||||
for sample_imgs in images:
|
||||
processed = self.image_processor(sample_imgs, return_tensors="pt", **kwargs)["pixel_values"]
|
||||
V_exist = processed.size(0)
|
||||
|
||||
# Pad to self.num_views
|
||||
if V_exist < self.num_views:
|
||||
processed = torch.cat(
|
||||
[processed,
|
||||
processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Mask: True for valid slots, False for padding
|
||||
image_mask = torch.zeros(self.num_views, dtype=torch.bool, device=processed.device)
|
||||
image_mask[:V_exist] = True
|
||||
|
||||
batch_imgs.append(processed)
|
||||
batch_masks.append(image_mask)
|
||||
|
||||
image_input = torch.stack(batch_imgs, dim=0) # [B, num_views, C, H, W]
|
||||
image_mask = torch.stack(batch_masks, dim=0) # [B, num_views]
|
||||
|
||||
return {"image_input": image_input, "image_mask": image_mask}
|
||||
|
||||
# ================== COMBINED CALL ==================
|
||||
def __call__(
|
||||
self,
|
||||
images: Optional[Union[List, List[List]]] = None,
|
||||
language_instruction: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Combine image and text encoding into a unified multimodal input.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
images : List or List[List], optional
|
||||
Single-sample or batched multi-view images.
|
||||
language_instruction : str or List[str], optional
|
||||
Corresponding text instructions.
|
||||
kwargs : dict
|
||||
Extra args passed to image processor.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, torch.Tensor]
|
||||
{
|
||||
"input_ids": [B, L], optional,
|
||||
"image_input": [B, num_views, C, H, W], optional,
|
||||
"image_mask": [B, num_views], optional
|
||||
}
|
||||
"""
|
||||
outputs: Dict[str, Any] = {}
|
||||
|
||||
# Encode language if provided
|
||||
if language_instruction is not None:
|
||||
outputs.update(self.encode_language(language_instruction))
|
||||
|
||||
# Encode image if provided
|
||||
if images is not None:
|
||||
outputs.update(self.encode_image(images, **kwargs))
|
||||
|
||||
# Sanity check for batch alignment
|
||||
if "input_ids" in outputs and "image_input" in outputs:
|
||||
assert outputs["input_ids"].size(0) == outputs["image_input"].size(0), (
|
||||
f"Batch mismatch: text batch {outputs['input_ids'].size(0)} "
|
||||
f"!= image batch {outputs['image_input'].size(0)}"
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
def make_xvla_pre_post_processors(
|
||||
config: XVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Build the LeRobot processor pipelines for XVLA.
|
||||
"""
|
||||
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.tokenizer_name,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding=config.pad_language_to,
|
||||
padding_side=config.tokenizer_padding_side,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(features=features, norm_map=config.normalization_mapping, stats=dataset_stats),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,403 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Final, Iterable, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ------------------------------- Small utils ----------------------------------
|
||||
|
||||
def _to_2tuple(x) -> Tuple:
|
||||
"""Minimal replacement for timm.layers.to_2tuple."""
|
||||
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
|
||||
t = tuple(x)
|
||||
return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
|
||||
return (x, x)
|
||||
|
||||
|
||||
def _has_sdp_attention() -> bool:
|
||||
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
|
||||
return hasattr(F, "scaled_dot_product_attention")
|
||||
|
||||
|
||||
# ---------------------------------- MLP --------------------------------------
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""
|
||||
MLP used in ViT-style blocks.
|
||||
|
||||
Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
norm_layer: type[nn.Module] | None = None,
|
||||
bias: bool | Tuple[bool, bool] = True,
|
||||
drop: float | Tuple[float, float] = 0.0,
|
||||
use_conv: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
bias = _to_2tuple(bias)
|
||||
drop_probs = _to_2tuple(drop)
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||
|
||||
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
||||
self.act = nn.GELU(approximate="tanh")
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Expect [B, T, C] for Linear variant; caller is responsible for shapes.
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
# -------------------------------- Attention ----------------------------------
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Multi-Head Self-Attention with optional fused SDPA fallback.
|
||||
|
||||
If PyTorch provides `scaled_dot_product_attention`, it will be used
|
||||
(usually faster and more stable); otherwise we use a manual implementation.
|
||||
"""
|
||||
|
||||
fused_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.fused_attn = _has_sdp_attention()
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor, shape [B, T, C]
|
||||
Input sequence.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor, shape [B, T, C]
|
||||
Output sequence after MHSA + projection.
|
||||
"""
|
||||
B, T, C = x.shape
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(B, T, 3, self.num_heads, self.head_dim)
|
||||
.permute(2, 0, 3, 1, 4) # 3 x [B, H, T, Dh]
|
||||
)
|
||||
q, k, v = qkv.unbind(0) # each: [B, H, T, Dh]
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
) # [B, H, T, Dh]
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v # [B, H, T, Dh]
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
# ------------------------------- Utilities -----------------------------------
|
||||
|
||||
def basic_init(module: nn.Module) -> None:
|
||||
"""
|
||||
Apply a basic initialization scheme to Linear layers.
|
||||
|
||||
- Weight: Xavier uniform initialization.
|
||||
- Bias: Set to zero.
|
||||
"""
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0.0)
|
||||
|
||||
|
||||
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : torch.Tensor
|
||||
Shape [B]. Each element is a timestep index, may be fractional.
|
||||
dim : int
|
||||
Dimensionality of the output embedding.
|
||||
max_period : int, default=100
|
||||
Controls the minimum frequency of the sinusoids.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Shape [B, dim]. Sinusoidal embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
|
||||
/ half
|
||||
)
|
||||
args = t[:, None] * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2 == 1:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
# ------------------------------- Core Layers ----------------------------------
|
||||
|
||||
class DomainAwareLinear(nn.Module):
|
||||
"""
|
||||
Linear layer with domain-conditioned parameters (per-sample).
|
||||
|
||||
Each domain has its own weight and bias vectors, stored in embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.fc = nn.Embedding(num_domains, output_size * input_size)
|
||||
self.bias = nn.Embedding(num_domains, output_size)
|
||||
nn.init.xavier_uniform_(self.fc.weight)
|
||||
nn.init.zeros_(self.bias.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor
|
||||
[B, I] or [B, T, I]
|
||||
domain_id : LongTensor
|
||||
[B], domain indices.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor
|
||||
[B, O] or [B, T, O]
|
||||
"""
|
||||
B = domain_id.shape[0]
|
||||
squeeze_T = False
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
squeeze_T = True
|
||||
W = self.fc(domain_id).view(B, self.input_size, self.output_size)
|
||||
b = self.bias(domain_id).view(B, self.output_size)
|
||||
y = torch.matmul(x, W) + b.view(B, 1, self.output_size)
|
||||
if squeeze_T:
|
||||
y = y.squeeze(1)
|
||||
return y
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
"""
|
||||
Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size)
|
||||
self.norm2 = nn.LayerNorm(hidden_size)
|
||||
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=int(hidden_size * mlp_ratio),
|
||||
drop=0.1,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor, [B, T, H]
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor, [B, T, H]
|
||||
"""
|
||||
x = x + self.attn(self.norm1(x))
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
# --------------------------- Main Model ---------------------------------------
|
||||
|
||||
class SoftPromptedTransformer(nn.Module):
|
||||
"""
|
||||
Multi-modal, domain-aware Transformer with optional soft prompts.
|
||||
|
||||
See parameter and forward I/O descriptions inside the docstrings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 768,
|
||||
multi_modal_input_size: int = 768,
|
||||
depth: int = 24,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
num_domains: int = 20,
|
||||
dim_action: int = 20,
|
||||
dim_propio: int = 20,
|
||||
dim_time: int = 32,
|
||||
len_soft_prompts: int = 32,
|
||||
max_len_seq: int = 512,
|
||||
use_hetero_proj: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dim_action = dim_action
|
||||
self.dim_time = dim_time
|
||||
self.len_soft_prompts = len_soft_prompts
|
||||
self.use_hetero_proj = use_hetero_proj
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
|
||||
)
|
||||
|
||||
if use_hetero_proj:
|
||||
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
|
||||
self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
|
||||
else:
|
||||
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
|
||||
nn.init.normal_(self.pos_emb, std=0.02)
|
||||
|
||||
self.norm = nn.LayerNorm(hidden_size)
|
||||
self.action_encoder = DomainAwareLinear(
|
||||
dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
|
||||
)
|
||||
self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
|
||||
|
||||
if len_soft_prompts > 0:
|
||||
self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
|
||||
nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
|
||||
|
||||
self.apply(basic_init)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
domain_id: torch.LongTensor,
|
||||
vlm_features: torch.Tensor,
|
||||
aux_visual_inputs: torch.Tensor,
|
||||
action_with_noise: torch.Tensor,
|
||||
proprio: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass.
|
||||
|
||||
Inputs
|
||||
------
|
||||
domain_id : [B]
|
||||
vlm_features : [B, T_vlm, D]
|
||||
aux_visual_inputs : [B, T_aux, D]
|
||||
action_with_noise : [B, T_action, dim_action]
|
||||
proprio : [B, dim_propio]
|
||||
t : [B]
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor
|
||||
Predicted actions, [B, T_action, dim_action]
|
||||
"""
|
||||
B, num_actions = action_with_noise.shape[:2]
|
||||
|
||||
# Encode (action + proprio + time) → tokens
|
||||
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
|
||||
time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
|
||||
proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
|
||||
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
|
||||
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
|
||||
|
||||
# Project visual streams and concatenate
|
||||
if self.use_hetero_proj:
|
||||
x = torch.cat(
|
||||
[x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
|
||||
|
||||
# Add positional embeddings (truncate if needed)
|
||||
seq_len = x.shape[1]
|
||||
if seq_len > self.pos_emb.shape[1]:
|
||||
raise ValueError(
|
||||
f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
|
||||
)
|
||||
x = x + self.pos_emb[:, :seq_len, :]
|
||||
|
||||
# Append soft prompts
|
||||
if self.len_soft_prompts > 0:
|
||||
soft_prompts = self.soft_prompt_hub(domain_id).view(B, self.len_soft_prompts, self.hidden_size)
|
||||
x = torch.cat([x, soft_prompts], dim=1)
|
||||
|
||||
# Transformer backbone
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
# Decode only the action segment
|
||||
return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
|
||||
@@ -0,0 +1,11 @@
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
|
||||
cfg = make_policy_config("xvla")
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
policy = make_policy(cfg=cfg, ds_meta=dataset_metadata)
|
||||
print(policy)
|
||||
Reference in New Issue
Block a user