major pre-commit cleanup

Jade Choghari
2025-11-17 14:30:56 +01:00
parent 858626dea5
commit 42d615b69d
12 changed files with 180 additions and 559 deletions
-57
View File
@@ -1,57 +0,0 @@
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
model = AutoModel.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
# append 3 random images to a list
def make_random_pil_images(num_images=3, H=480, W=640):
images = []
for _ in range(num_images):
# Random RGB image
arr = np.random.randint(0, 256, (H, W, 3), dtype=np.uint8)
img = Image.fromarray(arr)
images.append(img)
return images
# Example:
images = make_random_pil_images()
language_instruction = "This is a random image"
# Multimodal preprocessing by processor
inputs = processor(images, language_instruction)
if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
raise ValueError("Processor did not return the expected keys.")
proprio = torch.randn(1, 20)
domain_id = torch.tensor([0], dtype=torch.long)
# Align to model's device/dtype
device = model.device
dtype = next(model.parameters()).dtype
def to_model(t: torch.Tensor) -> torch.Tensor:
if not isinstance(t, torch.Tensor):
t = torch.as_tensor(t)
# cast floats to model dtype, keep integral/bool as-is
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
inputs = {k: to_model(v) for k, v in inputs.items()}
inputs.update(
{
"proprio": to_model(proprio),
"domain_id": domain_id.to(device),
}
)
# Inference
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
breakpoint()
-63
View File
@@ -1,63 +0,0 @@
import numpy as np
import torch
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
observation_height: int = 360
observation_width: int = 360
# create an observation dict
OBS = {
f"{OBS_IMAGES}.image1": torch.randn(1, 3, observation_height, observation_width),
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
"task": "put the object in the box",
}
def fake_rgb(H, W):
img = torch.randint(0, 255, (H, W, 3), dtype=torch.uint8).numpy()
return img
OBS[f"{OBS_IMAGES}.image1"] = fake_rgb(observation_height, observation_width)
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
# observation = preprocessor(OBS)
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True)
inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
breakpoint()
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
env_cfg = make_env_config("libero", task="libero_spatial")
policy = make_policy(
cfg=cfg,
env_cfg=env_cfg,
)
policy.eval()
preprocessor_overrides = {
"device_processor": {"device": str(cfg.device)},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path=cfg.pretrained_path,
preprocessor_overrides=preprocessor_overrides,
)
observation = preprocessor(OBS)
action = policy.select_action(observation)
target_eef = action[:, :3].to("cpu").numpy()
target_axis = rotate6d_to_axis_angle(action[:, 3:9].to("cpu").numpy())
target_act = action[:, 9:10].to("cpu").numpy()
final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1)
-237
View File
@@ -1,237 +0,0 @@
import os
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy, make_policy_config
cfg = make_policy_config("xvla")
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
policy = make_policy(cfg=cfg, ds_meta=dataset_metadata)
for name, param in policy.state_dict().items():
print(name, param.shape)
# now let's load in safetensors
import safetensors.torch
from huggingface_hub import snapshot_download
cache_dir = snapshot_download(
repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model"
)
state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors"))
# policy.load_state_dict(state_dict)
# 3. Add "model." prefix to every key
new_state_dict = {f"model.{k}": v for k, v in state_dict.items()}
keys_to_skip = [
"model.transformer.action_encoder.fc.weight",
"model.transformer.action_encoder.fc.bias",
]
new_state_dict = {k: v for k, v in new_state_dict.items() if k not in keys_to_skip}
# 4. Load into your model
missing, unexpected = policy.load_state_dict(new_state_dict, strict=False)
print("missing keys:", missing)
print()
print("unexpected keys:", unexpected)
import random
import numpy as np
import torch
from xvla.models.modeling_xvla import XVLA
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
observation_height: int = 224
observation_width: int = 224 # todo: jadechoghari, image size is different for the two models
# create an observation dict
OBS = {
f"{OBS_IMAGES}.image": torch.randn(1, 3, observation_height, observation_width),
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
OBS_STATE: torch.randn(1, 20), # ONLY if OBS_STATE is already a string
"task": "put the object in the box",
}
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
def fake_rgb(H, W):
arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
t = t.unsqueeze(0).float()
# normalize pixel to imagenet
return t
OBS[f"{OBS_IMAGES}.image"] = fake_rgb(observation_height, observation_width)
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
env_cfg = make_env_config("libero", task="libero_spatial")
policy = make_policy(
cfg=cfg,
env_cfg=env_cfg,
)
policy.eval()
preprocessor_overrides = {
"device_processor": {"device": str(cfg.device)},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path=cfg.pretrained_path,
preprocessor_overrides=preprocessor_overrides,
)
observation = preprocessor(OBS)
inputs = policy._build_model_inputs(observation)
#### now the og model ###########################################################
from xvla.models.processing_xvla import XVLAProcessor
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
domain_id = torch.tensor([3], dtype=torch.long)
inputs.update(
{
"proprio": OBS[OBS_STATE].to("cuda"),
"domain_id": domain_id.to("cuda"),
}
)
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
a = inputs[k]
b = inputs_1[k].to("cuda")
print(f"\n🔎 Key: {k}")
# Check shape
print(" shape:", a.shape, b.shape)
# Check if close
if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
print(" ✔️ tensors are equal (allclose)")
else:
diff = torch.abs(a - b)
print(" ❌ tensors differ")
print(" max diff:", diff.max().item())
print(" mean diff:", diff.mean().item())
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
model.eval()
model.to("cuda")
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
# (Pdb) inputs['input_ids'].shape
# torch.Size([1, 64])
# (Pdb) inputs_1['input_ids'].shape
# torch.Size([1, 50])
# (Pdb) [0, 0, :, :4, 0]
action_1 = policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
# np all close
print(np.allclose(action, action_1, atol=1e-2, rtol=1e-2))
print("max diff:", np.max(np.abs(action - action_1)))
print("mean diff:", np.mean(np.abs(action - action_1)))
import random
import numpy as np
import torch
from PIL import Image
from xvla.models.configuration_xvla import XVLAConfig
from xvla.models.modeling_xvla import XVLA
from xvla.models.processor_xvla import XVLAProcessor
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy
cfg = XVLAConfig.from_pretrained("/raid/jade/models/xvla-libero")
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
model.eval()
model.to("cuda")
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero")
# /raid/jade/models/xvla-libero
# set seed
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
def make_random_pil_images(num_images=3, H=480, W=640):
images = []
for _ in range(num_images):
# Random RGB image
arr = np.random.randint(0, 256, (H, W, 3), dtype=np.uint8)
img = Image.fromarray(arr)
images.append(img)
return images
# Example:
images = make_random_pil_images()
language_instruction = "This is a random image"
# Multimodal preprocessing by processor
inputs = processor(images, language_instruction)
if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
raise ValueError("Processor did not return the expected keys.")
proprio = torch.randn(1, 20)
domain_id = torch.tensor([0], dtype=torch.long)
# Align to model's device/dtype
device = model.device
dtype = next(model.parameters()).dtype
def to_model(t: torch.Tensor) -> torch.Tensor:
if not isinstance(t, torch.Tensor):
t = torch.as_tensor(t)
# cast floats to model dtype, keep integral/bool as-is
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
inputs = {k: to_model(v) for k, v in inputs.items()}
inputs.update(
{
"proprio": to_model(proprio),
"domain_id": domain_id.to(device),
}
)
# Inference
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
#### now for lerobot model #####################################################
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
env_cfg = make_env_config("libero", task="libero_spatial")
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
policy = make_policy(cfg=cfg, env_cfg=env_cfg)
policy.eval()
policy.to("cuda")
action_1 = policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
+11 -11
View File
@@ -101,10 +101,10 @@ class BaseActionSpace(nn.Module):
# =============================================================================
# Utilities
# =============================================================================
def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None:
bad = [i for i in idx if i < 0 or i >= D]
def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
bad = [i for i in idx if i < 0 or i >= dim_action]
if bad:
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}")
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")
# =============================================================================
@@ -132,8 +132,8 @@ class EE6DActionSpace(BaseActionSpace):
def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
B, T, D = pred.shape
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
# Gripper BCE
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
@@ -188,13 +188,13 @@ class JointActionSpace(BaseActionSpace):
def compute_loss(self, pred, target):
assert pred.shape == target.shape
B, T, D = pred.shape
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
joints_idx = tuple(i for i in range(D) if i not in set(self.gripper_idx))
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
return {
@@ -237,8 +237,8 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
def compute_loss(self, pred, target):
assert pred.shape == target.shape
B, T, D = pred.shape
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
gripper_loss = (
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
@@ -280,7 +280,7 @@ class FrankaJoint7ActionSpace(BaseActionSpace):
def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
B, T, D = pred.shape
batch_size, seq_len, action_dim = pred.shape
joints_loss = self.mse(pred, target) * self.JOINTS_SCALE
return {"joints_loss": joints_loss}
@@ -12,12 +12,11 @@
# limitations under the License.
import warnings
""" Florence-2 configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
""" Florence-2 configuration"""
logger = logging.get_logger(__name__)
@@ -82,37 +81,54 @@ class Florence2VisionConfig(PretrainedConfig):
def __init__(
self,
drop_path_rate=0.1,
patch_size=[7, 3, 3, 3],
patch_stride=[4, 2, 2, 2],
patch_padding=[3, 1, 1, 1],
patch_prenorm=[False, True, True, True],
patch_size=None,
patch_stride=None,
patch_padding=None,
patch_prenorm=None,
enable_checkpoint=False,
dim_embed=[256, 512, 1024, 2048],
num_heads=[8, 16, 32, 64],
num_groups=[8, 16, 32, 64],
depths=[1, 1, 9, 1],
dim_embed=None,
num_heads=None,
num_groups=None,
depths=None,
window_size=12,
projection_dim=1024,
visual_temporal_embedding=None,
image_pos_embed=None,
image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
image_feature_source=None,
**kwargs,
):
self.drop_path_rate = drop_path_rate
self.patch_size = patch_size
self.patch_stride = patch_stride
self.patch_padding = patch_padding
self.patch_prenorm = patch_prenorm
self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3]
self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2]
self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1]
self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True]
self.enable_checkpoint = enable_checkpoint
self.dim_embed = dim_embed
self.num_heads = num_heads
self.num_groups = num_groups
self.depths = depths
self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64]
self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64]
self.depths = depths if depths is not None else [1, 1, 9, 1]
self.window_size = window_size
self.projection_dim = projection_dim
if visual_temporal_embedding is None:
visual_temporal_embedding = {
"type": "COSINE",
"max_temporal_embeddings": 100,
}
self.visual_temporal_embedding = visual_temporal_embedding
if image_pos_embed is None:
image_pos_embed = {
"type": "learned_abs_2d",
"max_pos_embeddings": 1000,
}
self.image_pos_embed = image_pos_embed
self.image_feature_source = image_feature_source
self.image_feature_source = (
image_feature_source
if image_feature_source is not None
else ["spatial_avg_pool", "temporal_avg_pool"]
)
super().__init__(**kwargs)
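The hunk above swaps mutable list defaults for None sentinels resolved inside __init__, so config instances no longer share list objects. A minimal standalone sketch of that pattern (class and field names here are illustrative only):
class VisionCfg:
    def __init__(self, dim_embed=None, depths=None):
        # Resolve None sentinels to fresh lists; a literal list default would be shared by all instances.
        self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
        self.depths = depths if depths is not None else [1, 1, 9, 1]
a, b = VisionCfg(), VisionCfg()
a.depths.append(3)
assert b.depths == [1, 1, 9, 1]  # b keeps its own default list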
@@ -264,7 +280,8 @@ class Florence2LanguageConfig(PretrainedConfig):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
"The config can simply be saved and uploaded again to be fixed."
"The config can simply be saved and uploaded again to be fixed.",
stacklevel=2,
)
@@ -116,54 +116,12 @@ class XVLAConfig(PreTrainedConfig):
Build (and cache) the Florence2 transformer config that should back the VLM.
"""
if self._florence_config_obj is None:
# TODO: jadechoghari: provide default way, and do not hardcode
# Ensure vision_config and text_config are provided with defaults if not specified
config_dict = dict(self.florence_config)
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
raise ValueError("vision_config is required")
# # Provide default vision config
# config_dict["vision_config"] = {
# "model_type": "davit",
# "drop_path_rate": 0.1,
# "patch_size": [7, 3, 3, 3],
# "patch_stride": [4, 2, 2, 2],
# "patch_padding": [3, 1, 1, 1],
# "patch_prenorm": [False, True, True, True],
# "enable_checkpoint": False,
# "dim_embed": [256, 512, 1024, 2048],
# "num_heads": [8, 16, 32, 64],
# "num_groups": [8, 16, 32, 64],
# "depths": [1, 1, 9, 1],
# "window_size": 12,
# "projection_dim": 1024,
# "visual_temporal_embedding": {
# "type": "COSINE",
# "max_temporal_embeddings": 100
# },
# "image_pos_embed": {
# "type": "learned_abs_2d",
# "max_pos_embeddings": 50
# },
# "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
# }
if "text_config" not in config_dict or config_dict["text_config"] is None:
raise ValueError("text_config is required")
# # Provide default text config
# config_dict["text_config"] = {
# "vocab_size": 51289,
# "activation_dropout": 0.1,
# "activation_function": "gelu",
# "attention_dropout": 0.1,
# "d_model": 1024,
# "decoder_attention_heads": 16,
# "decoder_layers": 12,
# "encoder_attention_heads": 16,
# "encoder_layers": 12,
# "dropout": 0.1,
# "max_position_embeddings": 4096,
# "num_hidden_layers": 12,
# "num_beams": 3
# }
self._florence_config_obj = Florence2Config(**config_dict)
return self._florence_config_obj
+90 -85
View File
@@ -19,7 +19,7 @@ from collections import OrderedDict
from dataclasses import dataclass
import torch
import torch.nn.functional as F
import torch.nn.functional as functional
import torch.utils.checkpoint
import torch.utils.checkpoint as checkpoint
from einops import rearrange
@@ -190,10 +190,7 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
class MySequential(nn.Sequential):
def forward(self, *inputs):
for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
inputs = module(*inputs) if isinstance(inputs, tuple) else module(inputs)
return inputs
@@ -206,7 +203,7 @@ class PreNorm(nn.Module):
def forward(self, x, *args, **kwargs):
shortcut = x
if self.norm != None:
if self.norm is not None:
x, size = self.fn(self.norm(x), *args, **kwargs)
else:
x, size = self.fn(x, *args, **kwargs)
@@ -259,11 +256,11 @@ class DepthWiseConv2d(nn.Module):
)
def forward(self, x, size):
B, N, C = x.shape
H, W = size
assert N == H * W
batch_size, num_tokens, channels = x.shape
height, width = size
assert num_tokens == height * width
x = self.dw(x.transpose(1, 2).view(B, C, H, W))
x = self.dw(x.transpose(1, 2).view(batch_size, channels, height, width))
size = (x.size(-2), x.size(-1))
x = x.flatten(2).transpose(1, 2)
return x, size
@@ -286,20 +283,20 @@ class ConvEmbed(nn.Module):
self.pre_norm = pre_norm
def forward(self, x, size):
H, W = size
height, width = size
if len(x.size()) == 3:
if self.norm and self.pre_norm:
x = self.norm(x)
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
x = rearrange(x, "b (h w) c -> b c h w", h=height, w=width)
x = self.proj(x)
_, _, H, W = x.shape
_, _, height, width = x.shape
x = rearrange(x, "b c h w -> b (h w) c")
if self.norm and not self.pre_norm:
x = self.norm(x)
return x, (H, W)
return x, (height, width)
class ChannelAttention(nn.Module):
@@ -311,16 +308,20 @@ class ChannelAttention(nn.Module):
self.proj = nn.Linear(dim, dim)
def forward(self, x, size):
B, N, C = x.shape
batch_size, num_tokens, channels = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
qkv = (
self.qkv(x)
.reshape(batch_size, num_tokens, 3, self.groups, channels // self.groups)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * (float(N) ** -0.5)
q = q * (float(num_tokens) ** -0.5)
attention = q.transpose(-1, -2) @ k
attention = attention.softmax(dim=-1)
x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
x = x.transpose(1, 2).reshape(B, N, C)
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
x = self.proj(x)
return x, size
@@ -366,18 +367,17 @@ class ChannelBlock(nn.Module):
def window_partition(x, window_size: int):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
batch_size, height, width, channels = x.shape
x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
return windows
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
B = batch_size
def window_reverse(windows, batch_size: int, window_size: int, height: int, width: int):
# this will cause onnx conversion failed for dynamic axis, because treated as constant
# int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
# int(windows.shape[0] / (height * width / window_size / window_size))
x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
return x
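A round-trip check of the renamed window helpers (self-contained sketch; the two functions are copied from the hunk above and the tensor sizes are arbitrary):
import torch
def window_partition(x, window_size: int):
    batch_size, height, width, channels = x.shape
    x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
def window_reverse(windows, batch_size: int, window_size: int, height: int, width: int):
    x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
x = torch.randn(2, 24, 24, 8)                  # (batch_size, height, width, channels)
windows = window_partition(x, window_size=12)  # (8, 12, 12, 8): 2 batches x 2x2 windows each
restored = window_reverse(windows, 2, 12, 24, 24)
assert torch.equal(x, restored)                # partition and reverse are exact inverses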
@@ -396,43 +396,47 @@ class WindowAttention(nn.Module):
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, size):
H, W = size
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
height, width = size
batch_size, seq_len, channels = x.shape
assert seq_len == height * width, "input feature has wrong size"
x = x.view(B, H, W, C)
x = x.view(batch_size, height, width, channels)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
pad_r = (self.window_size - width % self.window_size) % self.window_size
pad_b = (self.window_size - height % self.window_size) % self.window_size
x = functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, height_padded, width_padded, _ = x.shape
x = window_partition(x, self.window_size)
x = x.view(-1, self.window_size * self.window_size, C)
x = x.view(-1, self.window_size * self.window_size, channels)
# W-MSA/SW-MSA
# attn_windows = self.attn(x_windows)
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
batch_windows, num_tokens, channels = x.shape
qkv = (
self.qkv(x)
.reshape(batch_windows, num_tokens, 3, self.num_heads, channels // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = (attn @ v).transpose(1, 2).reshape(batch_windows, num_tokens, channels)
x = self.proj(x)
# merge windows
x = x.view(-1, self.window_size, self.window_size, C)
x = window_reverse(x, B, self.window_size, Hp, Wp)
x = x.view(-1, self.window_size, self.window_size, channels)
x = window_reverse(x, batch_size, self.window_size, height_padded, width_padded)
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x[:, :height, :width, :].contiguous()
x = x.view(B, H * W, C)
x = x.view(batch_size, height * width, channels)
return x, size
@@ -606,7 +610,7 @@ class DaViT(nn.Module):
x (_type_): input image tensor
"""
input_size = (x.size(2), x.size(3))
for conv, block in zip(self.convs, self.blocks):
for conv, block in zip(self.convs, self.blocks, strict=False):
x, input_size = conv(x, input_size)
if self.enable_checkpoint:
x, input_size = checkpoint.checkpoint(block, x, input_size)
@@ -656,7 +660,7 @@ def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens = functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
@@ -1184,7 +1188,7 @@ class Florence2SdpaAttention(Florence2Attention):
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
is_causal = bool(self.is_causal and attention_mask is None and tgt_len > 1)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
@@ -1440,7 +1444,7 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel):
for name, _ in module.named_parameters():
if name == "bias":
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.BatchNorm2d):
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(module.weight, 1.0)
nn.init.constant_(module.bias, 0)
@@ -1594,12 +1598,11 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
if head_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
if head_mask is not None and head_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -1845,12 +1848,11 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -1860,14 +1862,13 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip(
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"], strict=False
):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
if attn_mask is not None and attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -2494,37 +2495,39 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
def forward(self, pixel_values):
if len(pixel_values.shape) == 4:
batch_size, C, H, W = pixel_values.shape
T = 1
batch_size, channels, height, width = pixel_values.shape
num_frames = 1
x = self.vision_tower.forward_features_unpool(pixel_values)
else:
raise ValueError(f"invalid image shape {pixel_values.shape}")
if self.image_pos_embed is not None:
x = x.view(batch_size * T, -1, x.shape[-1])
x = x.view(batch_size * num_frames, -1, x.shape[-1])
num_tokens = x.shape[-2]
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
assert h * w == num_tokens, "only support square feature maps for now"
x = x.view(batch_size * T, h, w, x.shape[-1])
x = x.view(batch_size * num_frames, h, w, x.shape[-1])
pos_embed = self.image_pos_embed(x)
x = x + pos_embed
x = x.view(batch_size, T * h * w, x.shape[-1])
x = x.view(batch_size, num_frames * h * w, x.shape[-1])
if self.visual_temporal_embed is not None:
visual_temporal_embed = self.visual_temporal_embed(
x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0]
)
x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view(
1, num_frames, 1, x.shape[-1]
)
x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
x_feat_dict = {}
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2)
x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1)
x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1]
x_feat_dict["last_frame"] = x
new_x = []
@@ -2619,37 +2622,39 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
def _encode_image(self, pixel_values):
if len(pixel_values.shape) == 4:
batch_size, C, H, W = pixel_values.shape
T = 1
batch_size, channels, height, width = pixel_values.shape
num_frames = 1
x = self.vision_tower.forward_features_unpool(pixel_values)
else:
raise ValueError(f"invalid image shape {pixel_values.shape}")
if self.image_pos_embed is not None:
x = x.view(batch_size * T, -1, x.shape[-1])
x = x.view(batch_size * num_frames, -1, x.shape[-1])
num_tokens = x.shape[-2]
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
assert h * w == num_tokens, "only support square feature maps for now"
x = x.view(batch_size * T, h, w, x.shape[-1])
x = x.view(batch_size * num_frames, h, w, x.shape[-1])
pos_embed = self.image_pos_embed(x)
x = x + pos_embed
x = x.view(batch_size, T * h * w, x.shape[-1])
x = x.view(batch_size, num_frames * h * w, x.shape[-1])
if self.visual_temporal_embed is not None:
visual_temporal_embed = self.visual_temporal_embed(
x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0]
)
x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view(
1, num_frames, 1, x.shape[-1]
)
x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
x_feat_dict = {}
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2)
x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1)
x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1]
x_feat_dict["last_frame"] = x
new_x = []
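The two renamed hunks above compute the same pooled image features; a small standalone sketch of the pooling shapes (the sizes are assumptions, not the real feature map):
import torch
batch_size, num_frames, num_tokens, dim = 2, 1, 49, 1024
x = torch.randn(batch_size, num_frames, num_tokens, dim)
spatial_avg_pool = x.mean(dim=2)   # (batch_size, num_frames, dim): average over image tokens
temporal_avg_pool = x.mean(dim=1)  # (batch_size, num_tokens, dim): average over frames
last_frame = x[:, -1]              # (batch_size, num_tokens, dim): features of the final frame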
+2 -1
View File
@@ -20,6 +20,7 @@ from __future__ import annotations
import os
from collections import deque
from pathlib import Path
import torch
import torch.nn.functional as F # noqa: N812
@@ -31,7 +32,7 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
from .action_hub import build_action_space
from .configuration_florence2 import Florence2Config
from .configuration_xvla import XVLAConfig
from .configuration_xvla import XVLAConfig, XVLAConfig as PreTrainedConfig
from .modeling_florence2 import Florence2ForConditionalGeneration
from .transformer import SoftPromptedTransformer
+2 -5
View File
@@ -120,7 +120,7 @@ class XVLAImageScaleProcessorStep(ProcessorStep):
keys_to_scale = self.image_keys
if keys_to_scale is None:
# Auto-detect image keys
keys_to_scale = [k for k in obs.keys() if k.startswith("observation.images.")]
keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
# Scale each image
for key in keys_to_scale:
@@ -161,10 +161,7 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
"""Add domain_id to complementary data."""
new_transition = transition.copy()
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if comp is None:
comp = {}
else:
comp = comp.copy()
comp = {} if comp is None else comp.copy()
# Infer batch size from observation tensors
obs = new_transition.get(TransitionKey.OBSERVATION, {})
+30 -28
View File
@@ -23,7 +23,7 @@ from typing import Final
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as functional
# ------------------------------- Small utils ----------------------------------
@@ -38,7 +38,7 @@ def _to_2tuple(x) -> tuple:
def _has_sdp_attention() -> bool:
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
return hasattr(F, "scaled_dot_product_attention")
return hasattr(functional, "scaled_dot_product_attention")
# ---------------------------------- MLP --------------------------------------
@@ -127,38 +127,38 @@ class Attention(nn.Module):
"""
Parameters
----------
x : Tensor, shape [B, T, C]
x : Tensor, shape [batch_size, seq_len, channels]
Input sequence.
Returns
-------
Tensor, shape [B, T, C]
Tensor, shape [batch_size, seq_len, channels]
Output sequence after MHSA + projection.
"""
B, T, C = x.shape
batch_size, seq_len, channels = x.shape
qkv = (
self.qkv(x)
.reshape(B, T, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4) # 3 x [B, H, T, Dh]
.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4) # 3 x [batch_size, num_heads, seq_len, head_dim]
)
q, k, v = qkv.unbind(0) # each: [B, H, T, Dh]
q, k, v = qkv.unbind(0) # each: [batch_size, num_heads, seq_len, head_dim]
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
x = functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
) # [B, H, T, Dh]
) # [batch_size, num_heads, seq_len, head_dim]
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
attn = q @ k.transpose(-2, -1) # [batch_size, num_heads, seq_len, seq_len]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v # [B, H, T, Dh]
x = attn @ v # [batch_size, num_heads, seq_len, head_dim]
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
x = x.transpose(1, 2).reshape(batch_size, seq_len, channels) # [batch_size, seq_len, channels]
x = self.proj(x)
x = self.proj_drop(x)
return x
@@ -240,17 +240,17 @@ class DomainAwareLinear(nn.Module):
Returns
-------
Tensor
[B, O] or [B, T, O]
[batch_size, output_size] or [batch_size, seq_len, output_size]
"""
B = domain_id.shape[0]
squeeze_T = False
batch_size = domain_id.shape[0]
squeeze_seq = False
if x.dim() == 2:
x = x.unsqueeze(1)
squeeze_T = True
W = self.fc(domain_id).view(B, self.input_size, self.output_size)
b = self.bias(domain_id).view(B, self.output_size)
y = torch.matmul(x, W) + b.view(B, 1, self.output_size)
if squeeze_T:
squeeze_seq = True
weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
bias = self.bias(domain_id).view(batch_size, self.output_size)
y = torch.matmul(x, weight) + bias.view(batch_size, 1, self.output_size)
if squeeze_seq:
y = y.squeeze(1)
return y
@@ -370,16 +370,16 @@ class SoftPromptedTransformer(nn.Module):
Returns
-------
Tensor
Predicted actions, [B, T_action, dim_action]
Predicted actions, [batch_size, num_actions, dim_action]
"""
B, num_actions = action_with_noise.shape[:2]
batch_size, num_actions = action_with_noise.shape[:2]
# Encode (action + proprio + time) → tokens
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
time_emb = timestep_embedding(t, self.dim_time) # [batch_size, dim_time]
time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time)
proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1])
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
x = self.action_encoder(action_tokens, domain_id) # [batch_size, num_actions, hidden_size]
# Project visual streams and concatenate
if self.use_hetero_proj:
@@ -402,7 +402,9 @@ class SoftPromptedTransformer(nn.Module):
# Append soft prompts
if self.len_soft_prompts > 0:
soft_prompts = self.soft_prompt_hub(domain_id).view(B, self.len_soft_prompts, self.hidden_size)
soft_prompts = self.soft_prompt_hub(domain_id).view(
batch_size, self.len_soft_prompts, self.hidden_size
)
x = torch.cat([x, soft_prompts], dim=1)
# Transformer backbone
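The DomainAwareLinear refactor above keeps the same idea: a per-domain weight and bias are looked up from the domain id and applied with a batched matmul. A simplified, self-contained sketch (not the repository class; layer sizes are arbitrary):
import torch
import torch.nn as nn
class TinyDomainLinear(nn.Module):
    def __init__(self, num_domains, input_size, output_size):
        super().__init__()
        self.input_size, self.output_size = input_size, output_size
        self.fc = nn.Embedding(num_domains, input_size * output_size)  # per-domain weight
        self.bias = nn.Embedding(num_domains, output_size)             # per-domain bias
    def forward(self, x, domain_id):
        batch_size = domain_id.shape[0]
        weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
        bias = self.bias(domain_id).view(batch_size, 1, self.output_size)
        return torch.matmul(x, weight) + bias
layer = TinyDomainLinear(num_domains=4, input_size=8, output_size=16)
y = layer(torch.randn(2, 5, 8), torch.tensor([0, 3]))
print(y.shape)  # torch.Size([2, 5, 16])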
+5 -5
View File
@@ -1,5 +1,5 @@
import numpy as np
import robosuite.utils.transform_utils as T
import robosuite.utils.transform_utils as transform_utils
def rotate6d_to_axis_angle(r6d):
@@ -26,12 +26,12 @@ def rotate6d_to_axis_angle(r6d):
# b3
b3 = np.cross(b1, b2, axis=-1)
R = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
rotation_matrix = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
axis_angle_list = []
for i in range(R.shape[0]):
quat = T.mat2quat(R[i])
axis_angle = T.quat2axisangle(quat)
for i in range(rotation_matrix.shape[0]):
quat = transform_utils.mat2quat(rotation_matrix[i])
axis_angle = transform_utils.quat2axisangle(quat)
axis_angle_list.append(axis_angle)
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
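A usage sketch for the refactored rotation helper (the import path matches the evaluation script earlier in this diff; the 6D input layout is an assumption, and robosuite must be installed):
import numpy as np
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
r6d = np.tile(np.array([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]]), (4, 1))  # identity-like 6D rotations, shape (4, 6)
axis_angle = rotate6d_to_axis_angle(r6d)
print(axis_angle.shape)  # expected (4, 3), one axis-angle vector per input row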
-2
View File
@@ -51,8 +51,6 @@ from lerobot.utils.utils import (
init_logging,
)
# login to hf
def update_policy(
train_metrics: MetricsTracker,