first commit

This commit is contained in:
Jade Choghari
2025-11-07 11:54:46 +01:00
parent f6b16f6d97
commit d9e4d374c5
10 changed files with 4763 additions and 0 deletions
+2
View File
@@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
@@ -31,4 +32,5 @@ __all__ = [
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
"XVLAConfig",
]
+14
View File
@@ -39,6 +39,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
@@ -107,6 +108,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.groot.modeling_groot import GrootPolicy
return GrootPolicy
elif name == "xvla":
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
return XVLAPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -150,6 +155,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
elif policy_type == "xvla":
return XVLAConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -329,6 +336,13 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processing_xvla import make_xvla_pre_post_processors
processors = make_xvla_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
+275
View File
@@ -0,0 +1,275 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from typing import Iterable, Tuple, Dict, Type
import torch
import torch.nn as nn
# =============================================================================
# Registry
# =============================================================================
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}
def register_action(name: str):
"""Decorator for registering a new action space."""
def _wrap(cls):
key = name.lower()
if key in ACTION_REGISTRY:
raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
ACTION_REGISTRY[key] = cls
cls.name = key
return cls
return _wrap
def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
"""Instantiate a registered action space by name."""
key = name.lower()
if key not in ACTION_REGISTRY:
raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
return ACTION_REGISTRY[key](**kwargs)
# =============================================================================
# Base class
# =============================================================================
class BaseActionSpace(nn.Module):
"""
Abstract base class for all action-space definitions.
Each subclass defines:
- `dim_action`: dimension of the action vector.
- `gripper_idx`: indices of gripper channels.
- `compute_loss(pred, target)`: supervised loss for this space.
- `preprocess(proprio, action, mode)`: pre-step modifications.
- `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
"""
name: str = "base"
dim_action: int = 0
gripper_idx: Tuple[int, ...] = ()
def __init__(self):
super().__init__()
# ---------------------------------------------------------------------
# Core supervised loss
# ---------------------------------------------------------------------
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
raise NotImplementedError
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Alias for compute_loss."""
return self.compute_loss(pred, target)
# ---------------------------------------------------------------------
# Space-level hooks
# ---------------------------------------------------------------------
def preprocess(
self,
proprio: torch.Tensor,
action: torch.Tensor,
mode: str = "train",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Default: return unchanged."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Default: return unchanged."""
return action
# =============================================================================
# Utilities
# =============================================================================
def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None:
bad = [i for i in idx if i < 0 or i >= D]
if bad:
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}")
# =============================================================================
# Implementations
# =============================================================================
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
"""End-effector layout with xyz, 6D rotation, and gripper channels."""
dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 1.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0
POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
B, T, D = pred.shape
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
# Gripper BCE
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
# XYZ position
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
# Rotation 6D
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
@register_action("joint")
class JointActionSpace(BaseActionSpace):
"""Joint-space layout with joints + gripper only."""
dim_action = 14
gripper_idx = (6, 13)
GRIPPER_SCALE = 0.1
JOINTS_SCALE = 1.0
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape
B, T, D = pred.shape
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
joints_idx = tuple(i for i in range(D) if i not in set(self.gripper_idx))
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
@register_action("agibot_ee6d")
class AGIBOTEE6DActionSpace(BaseActionSpace):
"""AGI-bot variant of EE6DActionSpace using MSE for all components."""
dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 10.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0
POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape
B, T, D = pred.shape
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""No preprocessing applied in AGIBOT variant."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""AGIBOT does not postprocess."""
return action
# =============================================================================
# Exports
# =============================================================================
__all__ = [
"BaseActionSpace",
"build_action_space",
"register_action",
"EE6DActionSpace",
"JointActionSpace",
"AGIBOTEE6DActionSpace",
"ACTION_REGISTRY",
]
@@ -0,0 +1,339 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
""" Florence-2 configuration"""
from typing import Optional
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Florence2VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
drop_path_rate (`float`, *optional*, defaults to 0.1):
The dropout rate of the drop path layer.
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
The patch size of the image.
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
The patch stride of the image.
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
The patch padding of the image.
patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
Whether to apply layer normalization before the patch embedding layer.
enable_checkpoint (`bool`, *optional*, defaults to False):
Whether to enable checkpointing.
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
The dimension of the embedding layer.
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
The number of attention heads.
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
The number of groups.
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
The depth of the model.
window_size (`int`, *optional*, defaults to 12):
The window size of the model.
projection_dim (`int`, *optional*, defaults to 1024):
The dimension of the projection layer.
visual_temporal_embedding (`dict`, *optional*):
The configuration of the visual temporal embedding.
image_pos_embed (`dict`, *optional*):
The configuration of the image position embedding.
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
The source of the image feature.
Example:
```python
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
>>> # Initializing a Florence2 Vision style configuration
>>> configuration = Florence2VisionConfig()
>>> # Initializing a model (with random weights)
>>> model = Florence2VisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "davit"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
drop_path_rate=0.1,
patch_size=[7, 3, 3, 3],
patch_stride=[4, 2, 2, 2],
patch_padding=[3, 1, 1, 1],
patch_prenorm=[False, True, True, True],
enable_checkpoint=False,
dim_embed=[256, 512, 1024, 2048],
num_heads=[8, 16, 32, 64],
num_groups=[8, 16, 32, 64],
depths=[1, 1, 9, 1],
window_size=12,
projection_dim=1024,
visual_temporal_embedding=None,
image_pos_embed=None,
image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
**kwargs,
):
self.drop_path_rate = drop_path_rate
self.patch_size = patch_size
self.patch_stride = patch_stride
self.patch_padding = patch_padding
self.patch_prenorm = patch_prenorm
self.enable_checkpoint = enable_checkpoint
self.dim_embed = dim_embed
self.num_heads = num_heads
self.num_groups = num_groups
self.depths = depths
self.window_size = window_size
self.projection_dim = projection_dim
self.visual_temporal_embedding = visual_temporal_embedding
self.image_pos_embed = image_pos_embed
self.image_feature_source = image_feature_source
super().__init__(**kwargs)
class Florence2LanguageConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the BART
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 51289):
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Florence2LanguageModel`].
d_model (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
encoder_layers (`int`, *optional*, defaults to 12):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 12):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
scale_embedding (`bool`, *optional*, defaults to `False`):
Scale embeddings by diving by sqrt(d_model).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
num_labels (`int`, *optional*, defaults to 3):
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
forced_eos_token_id (`int`, *optional*, defaults to 2):
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
`eos_token_id`.
Example:
```python
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
>>> # Initializing a Florence2 Language style configuration
>>> configuration = Florence2LanguageConfig()
>>> # Initializing a model (with random weights)
>>> model = Florence2LangaugeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "florence2_language"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
def __init__(
self,
vocab_size=51289,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
activation_function="gelu",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
use_cache=True,
num_labels=3,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
is_encoder_decoder=True,
decoder_start_token_id=2,
forced_eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
num_labels=num_labels,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
# ensure backward compatibility for BART CNN models
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
"The config can simply be saved and uploaded again to be fixed."
)
class Florence2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
Florence-2 model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Florence2VisionConfig`, *optional*):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
vocab_size (`int`, *optional*, defaults to 51289):
Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
projection_dim (`int`, *optional*, defaults to 1024):
Dimension of the multimodal projection space.
Example:
```python
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
>>> # Initializing a clip-like vision config
>>> vision_config = CLIPVisionConfig()
>>> # Initializing a Bart config
>>> text_config = BartConfig()
>>> # Initializing a Florence-2 configuration
>>> configuration = Florence2Config(vision_config, text_config)
>>> # Initializing a model from the florence-2 configuration
>>> model = Florence2ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "florence2"
is_composition = False
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
vocab_size=51289,
projection_dim=1024,
**kwargs,
):
self.ignore_index = ignore_index
self.vocab_size = vocab_size
self.projection_dim = projection_dim
if vision_config is not None:
vision_config = Florence2VisionConfig(**vision_config)
self.vision_config = vision_config
self.vocab_size = self.vocab_size
self.text_config = text_config
if text_config is not None:
self.text_config = Florence2LanguageConfig(**text_config)
super().__init__(**kwargs)
@@ -0,0 +1,217 @@
#!/usr/bin/env python
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
from .configuration_florence2 import Florence2Config
@PreTrainedConfig.register_subclass("xvla")
@dataclass
class XVLAConfig(PreTrainedConfig):
"""
Configuration class for the XVLA (Extended Vision-Language-Action) policy so it can
plug into the LeRobot training stack.
The config mirrors the knobs exposed in the original XVLA repository but also
declares the input/output feature contract required by LeRobot.
"""
# Input / output structure
n_obs_steps: int = 1
chunk_size: int = 32
n_action_steps: int = 32
num_actions: int = 32
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Florence2 backbone and tokenizer configuration
florence_config: dict[str, Any] | Florence2Config = field(default_factory=dict)
tokenizer_name: str = "facebook/bart-large"
tokenizer_max_length: int = 64
tokenizer_padding_side: str = "right"
pad_language_to: str = "max_length"
# Transformer head
hidden_size: int = 1024
depth: int = 24
num_heads: int = 16
mlp_ratio: float = 4.0
num_domains: int = 30
len_soft_prompts: int = 32
dim_time: int = 32
max_len_seq: int = 512
use_hetero_proj: bool = False
# Action & proprioception
action_mode: str = "ee6d"
num_denoising_steps: int = 10
use_proprio: bool = True
max_state_dim: int = 32
domain_feature_key: str | None = None
# Vision preprocessing
resize_imgs_with_padding: tuple[int, int] | None = (518, 518)
num_image_views: int | None = None
empty_cameras: int = 0
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-4
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.num_actions <= 0:
raise ValueError("`num_actions` must be strictly positive.")
if self.chunk_size != self.num_actions:
self.chunk_size = self.num_actions
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
)
if isinstance(self.florence_config, Florence2Config):
self.florence_config = self.florence_config.to_dict()
if self.num_image_views is not None and self.num_image_views <= 0:
raise ValueError("`num_image_views` must be > 0 when specified.")
self._florence_config_obj: Florence2Config | None = None
def get_florence_config(self) -> Florence2Config:
"""
Build (and cache) the Florence2 transformer config that should back the VLM.
"""
if self._florence_config_obj is None:
if isinstance(self.florence_config, Florence2Config):
self._florence_config_obj = self.florence_config
else:
# TODO: jadechoghari: provide default way, and do not hardcode
# Ensure vision_config and text_config are provided with defaults if not specified
config_dict = dict(self.florence_config)
if 'vision_config' not in config_dict or config_dict['vision_config'] is None:
# Provide default vision config
config_dict['vision_config'] = {
'model_type': 'davit',
'drop_path_rate': 0.1,
'patch_size': [7, 3, 3, 3],
'patch_stride': [4, 2, 2, 2],
'patch_padding': [3, 1, 1, 1],
'patch_prenorm': [False, True, True, True],
'dim_embed': [256, 512, 1024, 2048],
'num_heads': [8, 16, 32, 64],
'num_groups': [8, 16, 32, 64],
'depths': [1, 1, 9, 1],
'window_size': 12,
'projection_dim': 1024,
'visual_temporal_embedding': {
'type': 'COSINE',
'max_temporal_embeddings': 100
},
'image_pos_embed': {
'type': 'learned_abs_2d',
'max_pos_embeddings': 50
},
'image_feature_source': ['spatial_avg_pool', 'temporal_avg_pool']
}
if 'text_config' not in config_dict or config_dict['text_config'] is None:
# Provide default text config
config_dict['text_config'] = {
'model_type': 'florence2_language',
'vocab_size': 51289,
'd_model': 1024,
'encoder_layers': 12,
'decoder_layers': 12,
'encoder_attention_heads': 16,
'decoder_attention_heads': 16,
'encoder_ffn_dim': 4096,
'decoder_ffn_dim': 4096,
}
self._florence_config_obj = Florence2Config(**config_dict)
return self._florence_config_obj
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("XVLA requires at least one visual feature in the inputs.")
if self.use_proprio and self.robot_state_feature is None:
raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")
if self.num_image_views is None:
self.num_image_views = len(self.image_features) + self.empty_cameras
else:
self.num_image_views = max(self.num_image_views, len(self.image_features) + self.empty_cameras)
if self.empty_cameras > 0:
height, width = (480, 640)
if self.resize_imgs_with_padding is not None:
height, width = self.resize_imgs_with_padding
for idx in range(self.empty_cameras):
key = f"{OBS_IMAGES}.empty_camera_{idx}"
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, height, width),
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int] | None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> list[int] | None:
return None
File diff suppressed because it is too large Load Diff
+388
View File
@@ -0,0 +1,388 @@
#!/usr/bin/env python
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from collections import deque
from typing import Dict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
from .action_hub import build_action_space
from .configuration_florence2 import Florence2Config
from .configuration_xvla import XVLAConfig
from .modeling_florence2 import Florence2ForConditionalGeneration
from .transformer import SoftPromptedTransformer
class XVLAModel(nn.Module):
"""
XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.
"""
def __init__(
self,
config: XVLAConfig,
florence_config: Florence2Config,
proprio_dim: int,
) -> None:
super().__init__()
self.config = config
self.num_actions: int = config.num_actions
self.use_proprio: bool = config.use_proprio
self.action_space = build_action_space(config.action_mode.lower())
self.dim_action = self.action_space.dim_action
self.dim_proprio = proprio_dim
self.vlm = Florence2ForConditionalGeneration(florence_config)
if hasattr(self.vlm, "language_model"):
lm = self.vlm.language_model
if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
del lm.model.decoder
if hasattr(lm, "lm_head"):
del lm.lm_head
projection_dim = getattr(self.vlm.config, "projection_dim", None)
if projection_dim is None:
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
self.transformer = SoftPromptedTransformer(
hidden_size=config.hidden_size,
multi_modal_input_size=projection_dim,
depth=config.depth,
num_heads=config.num_heads,
mlp_ratio=config.mlp_ratio,
num_domains=config.num_domains,
dim_action=self.dim_action,
dim_propio=self.dim_proprio,
len_soft_prompts=config.len_soft_prompts,
dim_time=config.dim_time,
max_len_seq=config.max_len_seq,
use_hetero_proj=config.use_hetero_proj,
)
def forward_vlm(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
image_mask: torch.Tensor,
) -> Dict[str, torch.Tensor]:
"""
Encode text and multi-view images via Florence2 encoder.
"""
batch_size, num_views = pixel_values.shape[:2]
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
flat_images = pixel_values.flatten(0, 1)
num_valid = int(flat_mask.sum().item())
if num_valid == 0:
raise ValueError("At least one image view must be valid per batch.")
valid_images = flat_images[flat_mask]
valid_feats = self.vlm._encode_image(valid_images)
tokens_per_view, hidden_dim = valid_feats.shape[1:]
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
image_features[flat_mask] = valid_feats
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
image_features[:, 0],
inputs_embeds,
)
enc_out = self.vlm.language_model.model.encoder(
attention_mask=attention_mask,
inputs_embeds=merged_embeds,
)[0]
aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
def forward(
self,
input_ids: torch.LongTensor,
image_input: torch.FloatTensor,
image_mask: torch.Tensor,
domain_id: torch.LongTensor,
proprio: torch.Tensor,
action: torch.Tensor,
) -> Dict[str, torch.Tensor]:
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
t = (torch.rand(1, device=input_ids.device) + torch.arange(batch_size, device=input_ids.device) / batch_size) % (
1 - 1e-5
)
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
pred_action = self.transformer(
domain_id=domain_id,
action_with_noise=action_noisy_m,
t=t,
proprio=proprio_m,
**enc,
)
return self.action_space.compute_loss(pred_action, action)
@torch.no_grad()
def generate_actions(
self,
input_ids: torch.LongTensor,
image_input: torch.FloatTensor,
image_mask: torch.Tensor,
domain_id: torch.LongTensor,
proprio: torch.Tensor,
steps: int,
) -> torch.Tensor:
self.eval()
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
action_dim = self.dim_action
x1 = torch.randn(batch_size, self.num_actions, action_dim, device=proprio.device, dtype=proprio.dtype)
action = torch.zeros_like(x1)
steps = max(1, int(steps))
for i in range(steps, 0, -1):
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=proprio.dtype)
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
action = self.transformer(
domain_id=domain_id,
action_with_noise=x_t_m,
proprio=proprio_m,
t=t,
**enc,
)
return self.action_space.postprocess(action)
class XVLAPolicy(PreTrainedPolicy):
"""LeRobot-compliant wrapper built around the XVLA model."""
config_class = XVLAConfig
name = "xvla"
def __init__(self, config: XVLAConfig):
super().__init__(config)
config.validate_features()
florence_config = config.get_florence_config()
proprio_dim = config.max_state_dim if config.use_proprio else 0
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
self.reset()
def reset(self) -> None:
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
def get_optim_params(self) -> dict:
return self.parameters()
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
if not self.config.use_proprio or OBS_STATE not in batch:
return torch.zeros(batch_size, 0, device=device)
state = batch[OBS_STATE]
if state.ndim > 2:
state = state[:, -1, :]
return pad_vector(state, self.model.dim_proprio)
def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
present_img_keys = [key for key in self.config.image_features if key in batch]
if len(present_img_keys) == 0:
raise ValueError(
"All image features are missing from the batch. "
f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
)
images = []
masks = []
for key in present_img_keys:
img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
images.append(img)
masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))
stacked_imgs = torch.stack(images, dim=1)
stacked_masks = torch.stack(masks, dim=1)
total_views = self.config.num_image_views or stacked_imgs.size(1)
total_views = max(total_views, stacked_imgs.size(1))
num_pad = total_views - stacked_imgs.size(1)
if num_pad > 0:
pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
pad_imgs = stacked_imgs.new_zeros(pad_shape)
pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)
return stacked_imgs, stacked_masks
def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
candidate = None
if self.config.domain_feature_key and self.config.domain_feature_key in batch:
candidate = batch[self.config.domain_feature_key]
elif "domain_id" in batch:
candidate = batch["domain_id"]
if candidate is None:
return torch.zeros(batch_size, dtype=torch.long, device=device)
if not isinstance(candidate, torch.Tensor):
candidate = torch.as_tensor(candidate, device=device)
else:
candidate = candidate.to(device=device)
if candidate.ndim == 0:
candidate = candidate.expand(batch_size)
if candidate.ndim > 1:
candidate = candidate.view(candidate.shape[0], -1)[:, 0]
if candidate.shape[0] != batch_size:
candidate = candidate.expand(batch_size)
return candidate.to(dtype=torch.long)
def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
if ACTION not in batch:
raise ValueError("Batch is missing action targets required for training.")
actions = batch[ACTION]
if actions.ndim == 2:
actions = actions.unsqueeze(1)
actions = pad_tensor_along_dim(actions, self.config.num_actions, dim=1)
if actions.shape[-1] != self.model.dim_action:
actions = pad_vector(actions, self.model.dim_action)
return actions
def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
input_ids = batch[OBS_LANGUAGE_TOKENS]
batch_size = input_ids.shape[0]
images, image_mask = self._prepare_images(batch)
domain_id = self._get_domain_id(batch, batch_size, images.device)
proprio = self._prepare_state(batch, batch_size, images.device)
return {
"input_ids": input_ids,
"image_input": images,
"image_mask": image_mask,
"domain_id": domain_id,
"proprio": proprio,
}
def _trim_action_dim(self, actions: Tensor) -> Tensor:
feature = self.config.action_feature
if feature is None:
return actions
desired_dim = feature.shape[0]
if desired_dim == actions.shape[-1]:
return actions
if desired_dim < actions.shape[-1]:
return actions[..., :desired_dim]
return pad_vector(actions, desired_dim)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
inputs = self._build_model_inputs(batch)
targets = self._prepare_action_targets(batch)
losses = self.model(action=targets, **inputs)
total_loss = sum(losses.values())
log_dict = {k: v.detach().item() for k, v in losses.items()}
log_dict["loss"] = total_loss.detach().item()
return total_loss, log_dict
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
inputs = self._build_model_inputs(batch)
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
actions = self._trim_action_dim(actions)
return actions
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
return self._get_action_chunk(batch)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self._get_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")
current_height, current_width = img.shape[2:]
if current_height == height and current_width == width:
return img
ratio = max(current_width / width, current_height / height)
resized_height = int(current_height / ratio)
resized_width = int(current_width / ratio)
resized_img = F.interpolate(img, size=(resized_height, resized_width), mode="bilinear", align_corners=False)
pad_height = max(0, height - resized_height)
pad_width = max(0, width - resized_width)
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
if vector.shape[-1] == new_dim:
return vector
if new_dim == 0:
shape = list(vector.shape)
shape[-1] = 0
return vector.new_zeros(*shape)
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = vector.new_zeros(*shape)
length = min(current_dim, new_dim)
new_vector[..., :length] = vector[..., :length]
return new_vector
def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
current_len = tensor.size(dim)
if current_len == target_len:
return tensor
if current_len > target_len:
slices = [slice(None)] * tensor.dim()
slices[dim] = slice(0, target_len)
return tensor[tuple(slices)]
pad_shape = list(tensor.shape)
pad_shape[dim] = target_len - current_len
pad_tensor = tensor.new_zeros(pad_shape)
return torch.cat([tensor, pad_tensor], dim=dim)
@@ -0,0 +1,268 @@
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from typing import Any, Dict, List, Optional, Union
import torch
from transformers import ProcessorMixin
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class XVLAProcessor(ProcessorMixin):
"""
XVLAProcessor: Unified multimodal processor for XVLA models.
Handles:
- Multi-view image inputs (e.g., from multiple cameras).
- Batch processing for multiple samples.
- Joint tokenization and image tensor preparation.
This processor combines an image processor and a tokenizer under a single interface
so that users can call it directly like:
>>> processor = XVLAProcessor.from_pretrained("path/to/xvla")
>>> inputs = processor(images=batch_images, language_instruction=batch_texts)
It is fully compatible with the Hugging Face AutoProcessor API.
Attributes
----------
num_views : int, default=3
Expected number of image views per sample. Missing views will be padded with zeros.
language_max_length : int, default=50
Maximum token length for text encoding.
attributes : list
Required by ProcessorMixin to know which submodules are stored and reloaded.
image_processor_class : str
The name of the associated image processor class.
tokenizer_class : tuple(str)
The names of compatible tokenizer classes.
"""
num_views: int = 3
language_max_length: int = 50
# Hugging Face ProcessorMixin-required metadata
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
def __init__(self, image_processor=None, tokenizer=None):
"""
Initialize XVLAProcessor.
Parameters
----------
image_processor : PreTrainedImageProcessor, optional
The image processor used to normalize/resize images.
tokenizer : PreTrainedTokenizer, optional
The tokenizer used for text tokenization.
"""
# ProcessorMixin automatically saves these under self.image_processor / self.tokenizer
super().__init__(image_processor, tokenizer)
# ================== LANGUAGE ENCODING ==================
def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
"""
Tokenize one or more language instructions.
Parameters
----------
language_instruction : str or List[str]
A single instruction or a batch of instructions.
Returns
-------
Dict[str, torch.Tensor]
{
"input_ids": tensor of shape [B, L]
}
"""
if isinstance(language_instruction, str):
language_instruction = [language_instruction]
inputs = self.tokenizer(
language_instruction,
return_tensors="pt",
padding="max_length",
max_length=self.language_max_length,
truncation=True,
)
return {"input_ids": inputs["input_ids"]}
# ================== IMAGE ENCODING ==================
def encode_image(
self,
images: Union[List, List[List]],
**kwargs
) -> Dict[str, torch.Tensor]:
"""
Preprocess one or more sets of multi-view images.
Parameters
----------
images : List or List[List]
Single sample: [img1, img2, ...]
Batch: [[img1a, img1b], [img2a, img2b, img2c], ...]
Each image may be a PIL.Image, NumPy array, or torch.Tensor.
kwargs : dict
Extra arguments passed to the underlying image processor
(e.g., `do_resize=False`, `size=(224,224)`).
Returns
-------
Dict[str, torch.Tensor]
{
"image_input": tensor [B, num_views, C, H, W],
"image_mask": tensor [B, num_views]
}
"""
# Normalize to batch form
if not isinstance(images[0], (list, tuple)):
images = [images] # convert single sample to batch of size 1
batch_imgs, batch_masks = [], []
for sample_imgs in images:
processed = self.image_processor(sample_imgs, return_tensors="pt", **kwargs)["pixel_values"]
V_exist = processed.size(0)
# Pad to self.num_views
if V_exist < self.num_views:
processed = torch.cat(
[processed,
processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
dim=0,
)
# Mask: True for valid slots, False for padding
image_mask = torch.zeros(self.num_views, dtype=torch.bool, device=processed.device)
image_mask[:V_exist] = True
batch_imgs.append(processed)
batch_masks.append(image_mask)
image_input = torch.stack(batch_imgs, dim=0) # [B, num_views, C, H, W]
image_mask = torch.stack(batch_masks, dim=0) # [B, num_views]
return {"image_input": image_input, "image_mask": image_mask}
# ================== COMBINED CALL ==================
def __call__(
self,
images: Optional[Union[List, List[List]]] = None,
language_instruction: Optional[Union[str, List[str]]] = None,
**kwargs
) -> Dict[str, torch.Tensor]:
"""
Combine image and text encoding into a unified multimodal input.
Parameters
----------
images : List or List[List], optional
Single-sample or batched multi-view images.
language_instruction : str or List[str], optional
Corresponding text instructions.
kwargs : dict
Extra args passed to image processor.
Returns
-------
Dict[str, torch.Tensor]
{
"input_ids": [B, L], optional,
"image_input": [B, num_views, C, H, W], optional,
"image_mask": [B, num_views], optional
}
"""
outputs: Dict[str, Any] = {}
# Encode language if provided
if language_instruction is not None:
outputs.update(self.encode_language(language_instruction))
# Encode image if provided
if images is not None:
outputs.update(self.encode_image(images, **kwargs))
# Sanity check for batch alignment
if "input_ids" in outputs and "image_input" in outputs:
assert outputs["input_ids"].size(0) == outputs["image_input"].size(0), (
f"Batch mismatch: text batch {outputs['input_ids'].size(0)} "
f"!= image batch {outputs['image_input'].size(0)}"
)
return outputs
def make_xvla_pre_post_processors(
config: XVLAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Build the LeRobot processor pipelines for XVLA.
"""
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.tokenizer_name,
max_length=config.tokenizer_max_length,
padding=config.pad_language_to,
padding_side=config.tokenizer_padding_side,
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(features=features, norm_map=config.normalization_mapping, stats=dataset_stats),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+403
View File
@@ -0,0 +1,403 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
import math
from functools import partial
from typing import Final, Iterable, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ------------------------------- Small utils ----------------------------------
def _to_2tuple(x) -> Tuple:
"""Minimal replacement for timm.layers.to_2tuple."""
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
t = tuple(x)
return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
return (x, x)
def _has_sdp_attention() -> bool:
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
return hasattr(F, "scaled_dot_product_attention")
# ---------------------------------- MLP --------------------------------------
class Mlp(nn.Module):
"""
MLP used in ViT-style blocks.
Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
"""
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
norm_layer: type[nn.Module] | None = None,
bias: bool | Tuple[bool, bool] = True,
drop: float | Tuple[float, float] = 0.0,
use_conv: bool = False,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = _to_2tuple(bias)
drop_probs = _to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = nn.GELU(approximate="tanh")
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Expect [B, T, C] for Linear variant; caller is responsible for shapes.
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
# -------------------------------- Attention ----------------------------------
class Attention(nn.Module):
"""
Multi-Head Self-Attention with optional fused SDPA fallback.
If PyTorch provides `scaled_dot_product_attention`, it will be used
(usually faster and more stable); otherwise we use a manual implementation.
"""
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = _has_sdp_attention()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor, shape [B, T, C]
Input sequence.
Returns
-------
Tensor, shape [B, T, C]
Output sequence after MHSA + projection.
"""
B, T, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, T, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4) # 3 x [B, H, T, Dh]
)
q, k, v = qkv.unbind(0) # each: [B, H, T, Dh]
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.0,
) # [B, H, T, Dh]
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v # [B, H, T, Dh]
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
x = self.proj(x)
x = self.proj_drop(x)
return x
# ------------------------------- Utilities -----------------------------------
def basic_init(module: nn.Module) -> None:
"""
Apply a basic initialization scheme to Linear layers.
- Weight: Xavier uniform initialization.
- Bias: Set to zero.
"""
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
"""
Create sinusoidal timestep embeddings.
Parameters
----------
t : torch.Tensor
Shape [B]. Each element is a timestep index, may be fractional.
dim : int
Dimensionality of the output embedding.
max_period : int, default=100
Controls the minimum frequency of the sinusoids.
Returns
-------
torch.Tensor
Shape [B, dim]. Sinusoidal embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
/ half
)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2 == 1:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
# ------------------------------- Core Layers ----------------------------------
class DomainAwareLinear(nn.Module):
"""
Linear layer with domain-conditioned parameters (per-sample).
Each domain has its own weight and bias vectors, stored in embeddings.
"""
def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.fc = nn.Embedding(num_domains, output_size * input_size)
self.bias = nn.Embedding(num_domains, output_size)
nn.init.xavier_uniform_(self.fc.weight)
nn.init.zeros_(self.bias.weight)
def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor
[B, I] or [B, T, I]
domain_id : LongTensor
[B], domain indices.
Returns
-------
Tensor
[B, O] or [B, T, O]
"""
B = domain_id.shape[0]
squeeze_T = False
if x.dim() == 2:
x = x.unsqueeze(1)
squeeze_T = True
W = self.fc(domain_id).view(B, self.input_size, self.output_size)
b = self.bias(domain_id).view(B, self.output_size)
y = torch.matmul(x, W) + b.view(B, 1, self.output_size)
if squeeze_T:
y = y.squeeze(1)
return y
class TransformerBlock(nn.Module):
"""
Standard Transformer block (pre-LN): LN MHSA residual, LN MLP residual.
"""
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size)
self.norm2 = nn.LayerNorm(hidden_size)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=int(hidden_size * mlp_ratio),
drop=0.1,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor, [B, T, H]
Returns
-------
Tensor, [B, T, H]
"""
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
# --------------------------- Main Model ---------------------------------------
class SoftPromptedTransformer(nn.Module):
"""
Multi-modal, domain-aware Transformer with optional soft prompts.
See parameter and forward I/O descriptions inside the docstrings.
"""
def __init__(
self,
hidden_size: int = 768,
multi_modal_input_size: int = 768,
depth: int = 24,
num_heads: int = 16,
mlp_ratio: float = 4.0,
num_domains: int = 20,
dim_action: int = 20,
dim_propio: int = 20,
dim_time: int = 32,
len_soft_prompts: int = 32,
max_len_seq: int = 512,
use_hetero_proj: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.dim_action = dim_action
self.dim_time = dim_time
self.len_soft_prompts = len_soft_prompts
self.use_hetero_proj = use_hetero_proj
self.blocks = nn.ModuleList(
[TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
)
if use_hetero_proj:
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
else:
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
nn.init.normal_(self.pos_emb, std=0.02)
self.norm = nn.LayerNorm(hidden_size)
self.action_encoder = DomainAwareLinear(
dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
)
self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
if len_soft_prompts > 0:
self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
self.apply(basic_init)
def forward(
self,
domain_id: torch.LongTensor,
vlm_features: torch.Tensor,
aux_visual_inputs: torch.Tensor,
action_with_noise: torch.Tensor,
proprio: torch.Tensor,
t: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass.
Inputs
------
domain_id : [B]
vlm_features : [B, T_vlm, D]
aux_visual_inputs : [B, T_aux, D]
action_with_noise : [B, T_action, dim_action]
proprio : [B, dim_propio]
t : [B]
Returns
-------
Tensor
Predicted actions, [B, T_action, dim_action]
"""
B, num_actions = action_with_noise.shape[:2]
# Encode (action + proprio + time) → tokens
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
# Project visual streams and concatenate
if self.use_hetero_proj:
x = torch.cat(
[x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
dim=1,
)
else:
x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
# Add positional embeddings (truncate if needed)
seq_len = x.shape[1]
if seq_len > self.pos_emb.shape[1]:
raise ValueError(
f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
)
x = x + self.pos_emb[:, :seq_len, :]
# Append soft prompts
if self.len_soft_prompts > 0:
soft_prompts = self.soft_prompt_hub(domain_id).view(B, self.len_soft_prompts, self.hidden_size)
x = torch.cat([x, soft_prompts], dim=1)
# Transformer backbone
for block in self.blocks:
x = block(x)
# Decode only the action segment
return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
+11
View File
@@ -0,0 +1,11 @@
from lerobot.policies.factory import make_policy
from lerobot.policies.factory import make_policy_config
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
cfg = make_policy_config("xvla")
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
policy = make_policy(cfg=cfg, ds_meta=dataset_metadata)
print(policy)