add franka action

This commit is contained in:
Jade Choghari
2025-11-07 14:28:36 +01:00
parent 8a65623dec
commit 3cb14248a4
10 changed files with 508 additions and 458 deletions
+1 -1
View File
@@ -39,8 +39,8 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
+49 -16
View File
@@ -15,18 +15,21 @@
# ------------------------------------------------------------------------------
from __future__ import annotations
from typing import Iterable, Tuple, Dict, Type
from collections.abc import Iterable
import torch
import torch.nn as nn
# =============================================================================
# Registry
# =============================================================================
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
def register_action(name: str):
"""Decorator for registering a new action space."""
def _wrap(cls):
key = name.lower()
if key in ACTION_REGISTRY:
@@ -34,10 +37,11 @@ def register_action(name: str):
ACTION_REGISTRY[key] = cls
cls.name = key
return cls
return _wrap
def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
"""Instantiate a registered action space by name."""
key = name.lower()
if key not in ACTION_REGISTRY:
@@ -62,7 +66,7 @@ class BaseActionSpace(nn.Module):
name: str = "base"
dim_action: int = 0
gripper_idx: Tuple[int, ...] = ()
gripper_idx: tuple[int, ...] = ()
def __init__(self):
super().__init__()
@@ -70,10 +74,10 @@ class BaseActionSpace(nn.Module):
# ---------------------------------------------------------------------
# Core supervised loss
# ---------------------------------------------------------------------
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
raise NotImplementedError
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
"""Alias for compute_loss."""
return self.compute_loss(pred, target)
@@ -85,7 +89,7 @@ class BaseActionSpace(nn.Module):
proprio: torch.Tensor,
action: torch.Tensor,
mode: str = "train",
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Default: return unchanged."""
return proprio, action
@@ -137,14 +141,14 @@ class EE6DActionSpace(BaseActionSpace):
# XYZ position
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
# Rotation 6D
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
@@ -236,14 +240,16 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
B, T, D = pred.shape
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
gripper_loss = (
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
)
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
@@ -261,6 +267,32 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
return action
@register_action("franka_joint7")
class FrankaJoint7ActionSpace(BaseActionSpace):
"""Franka Panda joint-space: 7 joints, no gripper."""
dim_action = 7
JOINTS_SCALE = 1.0
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
B, T, D = pred.shape
joints_loss = self.mse(pred, target) * self.JOINTS_SCALE
return {"joints_loss": joints_loss}
def preprocess(self, proprio, action, mode="train"):
"""No preprocessing needed for 7 joint actions."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Return directly (no sigmoid since no gripper)."""
return action
# =============================================================================
# Exports
# =============================================================================
@@ -271,5 +303,6 @@ __all__ = [
"EE6DActionSpace",
"JointActionSpace",
"AGIBOTEE6DActionSpace",
"FrankaJoint7ActionSpace",
"ACTION_REGISTRY",
]
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,16 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
""" Florence-2 configuration"""
from typing import Optional
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Florence2VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
@@ -118,7 +117,6 @@ class Florence2VisionConfig(PretrainedConfig):
super().__init__(**kwargs)
class Florence2LanguageConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
@@ -269,6 +267,7 @@ class Florence2LanguageConfig(PretrainedConfig):
"The config can simply be saved and uploaded again to be fixed."
)
class Florence2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
@@ -335,5 +334,4 @@ class Florence2Config(PretrainedConfig):
if text_config is not None:
self.text_config = Florence2LanguageConfig(**text_config)
super().__init__(**kwargs)
+28 -34
View File
@@ -124,43 +124,37 @@ class XVLAConfig(PreTrainedConfig):
# TODO: jadechoghari: provide default way, and do not hardcode
# Ensure vision_config and text_config are provided with defaults if not specified
config_dict = dict(self.florence_config)
if 'vision_config' not in config_dict or config_dict['vision_config'] is None:
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
# Provide default vision config
config_dict['vision_config'] = {
'model_type': 'davit',
'drop_path_rate': 0.1,
'patch_size': [14, 7, 7, 7],
'patch_stride': [4, 2, 2, 2],
'patch_padding': [3, 1, 1, 1],
'patch_prenorm': [False, True, True, True],
'dim_embed': [256, 512, 1024, 2048],
'num_heads': [8, 16, 32, 64],
'num_groups': [8, 16, 32, 64],
'depths': [1, 1, 9, 1],
'window_size': 12,
'projection_dim': 1024,
'visual_temporal_embedding': {
'type': 'COSINE',
'max_temporal_embeddings': 100
},
'image_pos_embed': {
'type': 'learned_abs_2d',
'max_pos_embeddings': 50
},
'image_feature_source': ['spatial_avg_pool', 'temporal_avg_pool']
config_dict["vision_config"] = {
"model_type": "davit",
"drop_path_rate": 0.1,
"patch_size": [14, 7, 7, 7],
"patch_stride": [4, 2, 2, 2],
"patch_padding": [3, 1, 1, 1],
"patch_prenorm": [False, True, True, True],
"dim_embed": [256, 512, 1024, 2048],
"num_heads": [8, 16, 32, 64],
"num_groups": [8, 16, 32, 64],
"depths": [1, 1, 9, 1],
"window_size": 12,
"projection_dim": 1024,
"visual_temporal_embedding": {"type": "COSINE", "max_temporal_embeddings": 100},
"image_pos_embed": {"type": "learned_abs_2d", "max_pos_embeddings": 50},
"image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"],
}
if 'text_config' not in config_dict or config_dict['text_config'] is None:
if "text_config" not in config_dict or config_dict["text_config"] is None:
# Provide default text config
config_dict['text_config'] = {
'model_type': 'florence2_language',
'vocab_size': 51289,
'd_model': 1024,
'encoder_layers': 12,
'decoder_layers': 12,
'encoder_attention_heads': 16,
'decoder_attention_heads': 16,
'encoder_ffn_dim': 4096,
'decoder_ffn_dim': 4096,
config_dict["text_config"] = {
"model_type": "florence2_language",
"vocab_size": 51289,
"d_model": 1024,
"encoder_layers": 12,
"decoder_layers": 12,
"encoder_attention_heads": 16,
"decoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"decoder_ffn_dim": 4096,
}
self._florence_config_obj = Florence2Config(**config_dict)
return self._florence_config_obj
File diff suppressed because it is too large Load Diff
+9 -7
View File
@@ -19,7 +19,6 @@
from __future__ import annotations
from collections import deque
from typing import Dict
import torch
import torch.nn.functional as F # noqa: N812
@@ -87,7 +86,7 @@ class XVLAModel(nn.Module):
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
image_mask: torch.Tensor,
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
"""
Encode text and multi-view images via Florence2 encoder.
"""
@@ -129,13 +128,14 @@ class XVLAModel(nn.Module):
domain_id: torch.LongTensor,
proprio: torch.Tensor,
action: torch.Tensor,
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
t = (torch.rand(1, device=input_ids.device) + torch.arange(batch_size, device=input_ids.device) / batch_size) % (
1 - 1e-5
)
t = (
torch.rand(1, device=input_ids.device)
+ torch.arange(batch_size, device=input_ids.device) / batch_size
) % (1 - 1e-5)
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
@@ -350,7 +350,9 @@ def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float
ratio = max(current_width / width, current_height / height)
resized_height = int(current_height / ratio)
resized_width = int(current_width / ratio)
resized_img = F.interpolate(img, size=(resized_height, resized_width), mode="bilinear", align_corners=False)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, height - resized_height)
pad_width = max(0, width - resized_width)
+12 -15
View File
@@ -14,7 +14,7 @@
# limitations under the License.
# ------------------------------------------------------------------------------
from typing import Any, Dict, List, Optional, Union
from typing import Any
import torch
from transformers import ProcessorMixin
@@ -88,7 +88,7 @@ class XVLAProcessor(ProcessorMixin):
super().__init__(image_processor, tokenizer)
# ================== LANGUAGE ENCODING ==================
def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
def encode_language(self, language_instruction: str | list[str]) -> dict[str, torch.Tensor]:
"""
Tokenize one or more language instructions.
@@ -117,11 +117,7 @@ class XVLAProcessor(ProcessorMixin):
return {"input_ids": inputs["input_ids"]}
# ================== IMAGE ENCODING ==================
def encode_image(
self,
images: Union[List, List[List]],
**kwargs
) -> Dict[str, torch.Tensor]:
def encode_image(self, images: list | list[list], **kwargs) -> dict[str, torch.Tensor]:
"""
Preprocess one or more sets of multi-view images.
@@ -157,8 +153,7 @@ class XVLAProcessor(ProcessorMixin):
# Pad to self.num_views
if V_exist < self.num_views:
processed = torch.cat(
[processed,
processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
[processed, processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
dim=0,
)
@@ -177,10 +172,10 @@ class XVLAProcessor(ProcessorMixin):
# ================== COMBINED CALL ==================
def __call__(
self,
images: Optional[Union[List, List[List]]] = None,
language_instruction: Optional[Union[str, List[str]]] = None,
**kwargs
) -> Dict[str, torch.Tensor]:
images: list | list[list] | None = None,
language_instruction: str | list[str] | None = None,
**kwargs,
) -> dict[str, torch.Tensor]:
"""
Combine image and text encoding into a unified multimodal input.
@@ -202,7 +197,7 @@ class XVLAProcessor(ProcessorMixin):
"image_mask": [B, num_views], optional
}
"""
outputs: Dict[str, Any] = {}
outputs: dict[str, Any] = {}
# Encode language if provided
if language_instruction is not None:
@@ -243,7 +238,9 @@ def make_xvla_pre_post_processors(
padding_side=config.tokenizer_padding_side,
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(features=features, norm_map=config.normalization_mapping, stats=dataset_stats),
NormalizerProcessorStep(
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
output_steps = [
UnnormalizerProcessorStep(
+30 -20
View File
@@ -17,17 +17,18 @@
from __future__ import annotations
import math
from collections.abc import Iterable
from functools import partial
from typing import Final, Iterable, Tuple
from typing import Final
import torch
import torch.nn as nn
import torch.nn.functional as F
# ------------------------------- Small utils ----------------------------------
def _to_2tuple(x) -> Tuple:
def _to_2tuple(x) -> tuple:
"""Minimal replacement for timm.layers.to_2tuple."""
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
t = tuple(x)
@@ -42,6 +43,7 @@ def _has_sdp_attention() -> bool:
# ---------------------------------- MLP --------------------------------------
class Mlp(nn.Module):
"""
MLP used in ViT-style blocks.
@@ -55,8 +57,8 @@ class Mlp(nn.Module):
hidden_features: int | None = None,
out_features: int | None = None,
norm_layer: type[nn.Module] | None = None,
bias: bool | Tuple[bool, bool] = True,
drop: float | Tuple[float, float] = 0.0,
bias: bool | tuple[bool, bool] = True,
drop: float | tuple[float, float] = 0.0,
use_conv: bool = False,
) -> None:
super().__init__()
@@ -86,6 +88,7 @@ class Mlp(nn.Module):
# -------------------------------- Attention ----------------------------------
class Attention(nn.Module):
"""
Multi-Head Self-Attention with optional fused SDPA fallback.
@@ -110,7 +113,7 @@ class Attention(nn.Module):
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.fused_attn = _has_sdp_attention()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
@@ -143,17 +146,19 @@ class Attention(nn.Module):
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
) # [B, H, T, Dh]
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v # [B, H, T, Dh]
x = attn @ v # [B, H, T, Dh]
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
x = self.proj(x)
x = self.proj_drop(x)
return x
@@ -161,6 +166,7 @@ class Attention(nn.Module):
# ------------------------------- Utilities -----------------------------------
def basic_init(module: nn.Module) -> None:
"""
Apply a basic initialization scheme to Linear layers.
@@ -194,9 +200,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
/ half
-math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -207,6 +211,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc
# ------------------------------- Core Layers ----------------------------------
class DomainAwareLinear(nn.Module):
"""
Linear layer with domain-conditioned parameters (per-sample).
@@ -283,6 +288,7 @@ class TransformerBlock(nn.Module):
# --------------------------- Main Model ---------------------------------------
class SoftPromptedTransformer(nn.Module):
"""
Multi-modal, domain-aware Transformer with optional soft prompts.
@@ -318,7 +324,9 @@ class SoftPromptedTransformer(nn.Module):
if use_hetero_proj:
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
self.aux_visual_proj = DomainAwareLinear(
multi_modal_input_size, hidden_size, num_domains=num_domains
)
else:
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
@@ -367,16 +375,20 @@ class SoftPromptedTransformer(nn.Module):
B, num_actions = action_with_noise.shape[:2]
# Encode (action + proprio + time) → tokens
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
# Project visual streams and concatenate
if self.use_hetero_proj:
x = torch.cat(
[x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
[
x,
self.vlm_proj(vlm_features, domain_id),
self.aux_visual_proj(aux_visual_inputs, domain_id),
],
dim=1,
)
else:
@@ -385,9 +397,7 @@ class SoftPromptedTransformer(nn.Module):
# Add positional embeddings (truncate if needed)
seq_len = x.shape[1]
if seq_len > self.pos_emb.shape[1]:
raise ValueError(
f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
)
raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
x = x + self.pos_emb[:, :seq_len, :]
# Append soft prompts
+1 -2
View File
@@ -1,6 +1,5 @@
from lerobot.policies.factory import make_policy
from lerobot.policies.factory import make_policy_config
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy, make_policy_config
cfg = make_policy_config("xvla")
+1
View File
@@ -4,6 +4,7 @@ lerobot-train \
--output_dir=outputs/train/act_your_dataset \
--job_name=xvla_so101_pickplace \
--policy.device=cuda \
--policy.action_mode=franka_joint7 \
--wandb.enable=true \
--policy.repo_id=jadechoghari/xvla_policy \
--steps=10000