mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
major pre-commit cleanup
This commit is contained in:
@@ -1,57 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
model = AutoModel.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
|
||||
|
||||
|
||||
# append 3 random image to a list
|
||||
def make_random_pil_images(num_images=3, H=480, W=640):
|
||||
images = []
|
||||
for _ in range(num_images):
|
||||
# Random RGB image
|
||||
arr = np.random.randint(0, 256, (H, W, 3), dtype=np.uint8)
|
||||
img = Image.fromarray(arr)
|
||||
images.append(img)
|
||||
return images
|
||||
|
||||
|
||||
# Example:
|
||||
images = make_random_pil_images()
|
||||
language_instruction = "This is a random image"
|
||||
# Multimodal preprocessing by processor
|
||||
inputs = processor(images, language_instruction)
|
||||
if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
|
||||
raise ValueError("Processor did not return the expected keys.")
|
||||
|
||||
proprio = torch.randn(1, 20)
|
||||
domain_id = torch.tensor([0], dtype=torch.long)
|
||||
|
||||
# Align to model's device/dtype
|
||||
device = model.device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
|
||||
def to_model(t: torch.Tensor) -> torch.Tensor:
|
||||
if not isinstance(t, torch.Tensor):
|
||||
t = torch.as_tensor(t)
|
||||
# cast floats to model dtype, keep integral/bool as-is
|
||||
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
|
||||
|
||||
|
||||
inputs = {k: to_model(v) for k, v in inputs.items()}
|
||||
inputs.update(
|
||||
{
|
||||
"proprio": to_model(proprio),
|
||||
"domain_id": domain_id.to(device),
|
||||
}
|
||||
)
|
||||
|
||||
# Inference
|
||||
|
||||
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||
|
||||
breakpoint()
|
||||
@@ -1,63 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
observation_height: int = 360
|
||||
observation_width: int = 360
|
||||
# create an observation dict
|
||||
OBS = {
|
||||
f"{OBS_IMAGES}.image1": torch.randn(1, 3, observation_height, observation_width),
|
||||
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
|
||||
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
|
||||
"task": "put the object in the box",
|
||||
}
|
||||
|
||||
|
||||
def fake_rgb(H, W):
|
||||
img = torch.randint(0, 255, (H, W, 3), dtype=torch.uint8).numpy()
|
||||
return img
|
||||
|
||||
|
||||
OBS[f"{OBS_IMAGES}.image1"] = fake_rgb(observation_height, observation_width)
|
||||
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
|
||||
|
||||
# observation = preprocessor(OBS)
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True)
|
||||
inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
|
||||
breakpoint()
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
|
||||
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
|
||||
env_cfg = make_env_config("libero", task="libero_spatial")
|
||||
policy = make_policy(
|
||||
cfg=cfg,
|
||||
env_cfg=env_cfg,
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
|
||||
preprocessor_overrides = {
|
||||
"device_processor": {"device": str(cfg.device)},
|
||||
}
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path=cfg.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
|
||||
observation = preprocessor(OBS)
|
||||
action = policy.select_action(observation)
|
||||
|
||||
target_eef = action[:, :3].to("cpu").numpy()
|
||||
target_axis = rotate6d_to_axis_angle(action[:, 3:9].to("cpu").numpy())
|
||||
target_act = action[:, 9:10].to("cpu").numpy()
|
||||
final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1)
|
||||
@@ -1,237 +0,0 @@
|
||||
import os
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
|
||||
cfg = make_policy_config("xvla")
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
policy = make_policy(cfg=cfg, ds_meta=dataset_metadata)
|
||||
|
||||
for name, param in policy.state_dict().items():
|
||||
print(name, param.shape)
|
||||
|
||||
|
||||
# now let's load in safetensors
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
cache_dir = snapshot_download(
|
||||
repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model"
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors"))
|
||||
# policy.load_state_dict(state_dict)
|
||||
# 3. Add "model." prefix to every key
|
||||
new_state_dict = {f"model.{k}": v for k, v in state_dict.items()}
|
||||
keys_to_skip = [
|
||||
"model.transformer.action_encoder.fc.weight",
|
||||
"model.transformer.action_encoder.fc.bias",
|
||||
]
|
||||
|
||||
new_state_dict = {k: v for k, v in new_state_dict.items() if k not in keys_to_skip}
|
||||
# 4. Load into your model
|
||||
missing, unexpected = policy.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
print("missing keys:", missing)
|
||||
|
||||
print()
|
||||
print("unexpected keys:", unexpected)
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from xvla.models.modeling_xvla import XVLA
|
||||
|
||||
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
torch.manual_seed(42)
|
||||
random.seed(42)
|
||||
np.random.seed(42)
|
||||
observation_height: int = 224
|
||||
observation_width: int = 224 # todo: jadechoghari, image size is different for the two models
|
||||
# create an observation dict
|
||||
OBS = {
|
||||
f"{OBS_IMAGES}.image": torch.randn(1, 3, observation_height, observation_width),
|
||||
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
|
||||
OBS_STATE: torch.randn(1, 20), # ONLY if OBS_STATE is already a string
|
||||
"task": "put the object in the box",
|
||||
}
|
||||
|
||||
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||
|
||||
|
||||
def fake_rgb(H, W):
|
||||
arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
|
||||
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
|
||||
t = t.unsqueeze(0).float()
|
||||
# normalize pixel to imagenet
|
||||
return t
|
||||
|
||||
|
||||
OBS[f"{OBS_IMAGES}.image"] = fake_rgb(observation_height, observation_width)
|
||||
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
|
||||
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
|
||||
env_cfg = make_env_config("libero", task="libero_spatial")
|
||||
policy = make_policy(
|
||||
cfg=cfg,
|
||||
env_cfg=env_cfg,
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
|
||||
preprocessor_overrides = {
|
||||
"device_processor": {"device": str(cfg.device)},
|
||||
}
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path=cfg.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
|
||||
observation = preprocessor(OBS)
|
||||
inputs = policy._build_model_inputs(observation)
|
||||
|
||||
|
||||
#### now the og model ###########################################################
|
||||
from xvla.models.processing_xvla import XVLAProcessor
|
||||
|
||||
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
|
||||
inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
|
||||
domain_id = torch.tensor([3], dtype=torch.long)
|
||||
inputs.update(
|
||||
{
|
||||
"proprio": OBS[OBS_STATE].to("cuda"),
|
||||
"domain_id": domain_id.to("cuda"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
|
||||
a = inputs[k]
|
||||
b = inputs_1[k].to("cuda")
|
||||
|
||||
print(f"\n🔎 Key: {k}")
|
||||
|
||||
# Check shape
|
||||
print(" shape:", a.shape, b.shape)
|
||||
|
||||
# Check if close
|
||||
if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
|
||||
print(" ✔️ tensors are equal (allclose)")
|
||||
else:
|
||||
diff = torch.abs(a - b)
|
||||
print(" ❌ tensors differ")
|
||||
print(" max diff:", diff.max().item())
|
||||
print(" mean diff:", diff.mean().item())
|
||||
|
||||
|
||||
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
|
||||
model.eval()
|
||||
model.to("cuda")
|
||||
|
||||
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||
# (Pdb) inputs['input_ids'].shape
|
||||
# torch.Size([1, 64])
|
||||
# (Pdb) inputs_1['input_ids'].shape
|
||||
# torch.Size([1, 50])
|
||||
# (Pdb) [0, 0, :, :4, 0]
|
||||
action_1 = policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||
|
||||
# np all close
|
||||
print(np.allclose(action, action_1, atol=1e-2, rtol=1e-2))
|
||||
print("max diff:", np.max(np.abs(action - action_1)))
|
||||
print("mean diff:", np.mean(np.abs(action - action_1)))
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from xvla.models.configuration_xvla import XVLAConfig
|
||||
from xvla.models.modeling_xvla import XVLA
|
||||
from xvla.models.processor_xvla import XVLAProcessor
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy
|
||||
|
||||
cfg = XVLAConfig.from_pretrained("/raid/jade/models/xvla-libero")
|
||||
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
|
||||
model.eval()
|
||||
model.to("cuda")
|
||||
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero")
|
||||
# /raid/jade/models/xvla-libero
|
||||
# seet seed
|
||||
torch.manual_seed(42)
|
||||
random.seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
|
||||
def make_random_pil_images(num_images=3, H=480, W=640):
|
||||
images = []
|
||||
for _ in range(num_images):
|
||||
# Random RGB image
|
||||
arr = np.random.randint(0, 256, (H, W, 3), dtype=np.uint8)
|
||||
img = Image.fromarray(arr)
|
||||
images.append(img)
|
||||
return images
|
||||
|
||||
|
||||
# Example:
|
||||
images = make_random_pil_images()
|
||||
language_instruction = "This is a random image"
|
||||
# Multimodal preprocessing by processor
|
||||
inputs = processor(images, language_instruction)
|
||||
if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
|
||||
raise ValueError("Processor did not return the expected keys.")
|
||||
|
||||
proprio = torch.randn(1, 20)
|
||||
domain_id = torch.tensor([0], dtype=torch.long)
|
||||
|
||||
# Align to model's device/dtype
|
||||
device = model.device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
|
||||
def to_model(t: torch.Tensor) -> torch.Tensor:
|
||||
if not isinstance(t, torch.Tensor):
|
||||
t = torch.as_tensor(t)
|
||||
# cast floats to model dtype, keep integral/bool as-is
|
||||
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
|
||||
|
||||
|
||||
inputs = {k: to_model(v) for k, v in inputs.items()}
|
||||
inputs.update(
|
||||
{
|
||||
"proprio": to_model(proprio),
|
||||
"domain_id": domain_id.to(device),
|
||||
}
|
||||
)
|
||||
|
||||
# Inference
|
||||
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||
|
||||
|
||||
#### now for lerobot model #####################################################
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
|
||||
env_cfg = make_env_config("libero", task="libero_spatial")
|
||||
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
|
||||
policy = make_policy(cfg=cfg, env_cfg=env_cfg)
|
||||
policy.eval()
|
||||
policy.to("cuda")
|
||||
|
||||
action_1 = policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||
@@ -101,10 +101,10 @@ class BaseActionSpace(nn.Module):
|
||||
# =============================================================================
|
||||
# Utilities
|
||||
# =============================================================================
|
||||
def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None:
|
||||
bad = [i for i in idx if i < 0 or i >= D]
|
||||
def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
|
||||
bad = [i for i in idx if i < 0 or i >= dim_action]
|
||||
if bad:
|
||||
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}")
|
||||
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -132,8 +132,8 @@ class EE6DActionSpace(BaseActionSpace):
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape, "pred/target shapes must match"
|
||||
B, T, D = pred.shape
|
||||
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
|
||||
batch_size, seq_len, action_dim = pred.shape
|
||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||
|
||||
# Gripper BCE
|
||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||
@@ -188,13 +188,13 @@ class JointActionSpace(BaseActionSpace):
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape
|
||||
B, T, D = pred.shape
|
||||
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
|
||||
batch_size, seq_len, action_dim = pred.shape
|
||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||
|
||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
||||
|
||||
joints_idx = tuple(i for i in range(D) if i not in set(self.gripper_idx))
|
||||
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
|
||||
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
|
||||
|
||||
return {
|
||||
@@ -237,8 +237,8 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape
|
||||
B, T, D = pred.shape
|
||||
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
|
||||
batch_size, seq_len, action_dim = pred.shape
|
||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||
|
||||
gripper_loss = (
|
||||
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
|
||||
@@ -280,7 +280,7 @@ class FrankaJoint7ActionSpace(BaseActionSpace):
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape, "pred/target shapes must match"
|
||||
B, T, D = pred.shape
|
||||
batch_size, seq_len, action_dim = pred.shape
|
||||
joints_loss = self.mse(pred, target) * self.JOINTS_SCALE
|
||||
return {"joints_loss": joints_loss}
|
||||
|
||||
|
||||
@@ -12,12 +12,11 @@
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
|
||||
""" Florence-2 configuration"""
|
||||
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
""" Florence-2 configuration"""
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -82,37 +81,54 @@ class Florence2VisionConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
drop_path_rate=0.1,
|
||||
patch_size=[7, 3, 3, 3],
|
||||
patch_stride=[4, 2, 2, 2],
|
||||
patch_padding=[3, 1, 1, 1],
|
||||
patch_prenorm=[False, True, True, True],
|
||||
patch_size=None,
|
||||
patch_stride=None,
|
||||
patch_padding=None,
|
||||
patch_prenorm=None,
|
||||
enable_checkpoint=False,
|
||||
dim_embed=[256, 512, 1024, 2048],
|
||||
num_heads=[8, 16, 32, 64],
|
||||
num_groups=[8, 16, 32, 64],
|
||||
depths=[1, 1, 9, 1],
|
||||
dim_embed=None,
|
||||
num_heads=None,
|
||||
num_groups=None,
|
||||
depths=None,
|
||||
window_size=12,
|
||||
projection_dim=1024,
|
||||
visual_temporal_embedding=None,
|
||||
image_pos_embed=None,
|
||||
image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
|
||||
image_feature_source=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
self.patch_padding = patch_padding
|
||||
self.patch_prenorm = patch_prenorm
|
||||
self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3]
|
||||
self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2]
|
||||
self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1]
|
||||
self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True]
|
||||
self.enable_checkpoint = enable_checkpoint
|
||||
self.dim_embed = dim_embed
|
||||
self.num_heads = num_heads
|
||||
self.num_groups = num_groups
|
||||
self.depths = depths
|
||||
self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
|
||||
self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64]
|
||||
self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64]
|
||||
self.depths = depths if depths is not None else [1, 1, 9, 1]
|
||||
self.window_size = window_size
|
||||
self.projection_dim = projection_dim
|
||||
|
||||
if visual_temporal_embedding is None:
|
||||
visual_temporal_embedding = {
|
||||
"type": "COSINE",
|
||||
"max_temporal_embeddings": 100,
|
||||
}
|
||||
self.visual_temporal_embedding = visual_temporal_embedding
|
||||
|
||||
if image_pos_embed is None:
|
||||
image_pos_embed = {
|
||||
"type": "learned_abs_2d",
|
||||
"max_pos_embeddings": 1000,
|
||||
}
|
||||
self.image_pos_embed = image_pos_embed
|
||||
self.image_feature_source = image_feature_source
|
||||
|
||||
self.image_feature_source = (
|
||||
image_feature_source
|
||||
if image_feature_source is not None
|
||||
else ["spatial_avg_pool", "temporal_avg_pool"]
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -264,7 +280,8 @@ class Florence2LanguageConfig(PretrainedConfig):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
warnings.warn(
|
||||
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
|
||||
"The config can simply be saved and uploaded again to be fixed."
|
||||
"The config can simply be saved and uploaded again to be fixed.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -116,54 +116,12 @@ class XVLAConfig(PreTrainedConfig):
|
||||
Build (and cache) the Florence2 transformer config that should back the VLM.
|
||||
"""
|
||||
if self._florence_config_obj is None:
|
||||
# TODO: jadechoghari: provide default way, and do not hardcode
|
||||
# Ensure vision_config and text_config are provided with defaults if not specified
|
||||
config_dict = dict(self.florence_config)
|
||||
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
|
||||
raise ValueError("vision_config is required")
|
||||
# # Provide default vision config
|
||||
# config_dict["vision_config"] = {
|
||||
# "model_type": "davit",
|
||||
# "drop_path_rate": 0.1,
|
||||
# "patch_size": [7, 3, 3, 3],
|
||||
# "patch_stride": [4, 2, 2, 2],
|
||||
# "patch_padding": [3, 1, 1, 1],
|
||||
# "patch_prenorm": [False, True, True, True],
|
||||
# "enable_checkpoint": False,
|
||||
# "dim_embed": [256, 512, 1024, 2048],
|
||||
# "num_heads": [8, 16, 32, 64],
|
||||
# "num_groups": [8, 16, 32, 64],
|
||||
# "depths": [1, 1, 9, 1],
|
||||
# "window_size": 12,
|
||||
# "projection_dim": 1024,
|
||||
# "visual_temporal_embedding": {
|
||||
# "type": "COSINE",
|
||||
# "max_temporal_embeddings": 100
|
||||
# },
|
||||
# "image_pos_embed": {
|
||||
# "type": "learned_abs_2d",
|
||||
# "max_pos_embeddings": 50
|
||||
# },
|
||||
# "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
|
||||
# }
|
||||
|
||||
if "text_config" not in config_dict or config_dict["text_config"] is None:
|
||||
raise ValueError("text_config is required")
|
||||
# # Provide default text config
|
||||
# config_dict["text_config"] = {
|
||||
# "vocab_size": 51289,
|
||||
# "activation_dropout": 0.1,
|
||||
# "activation_function": "gelu",
|
||||
# "attention_dropout": 0.1,
|
||||
# "d_model": 1024,
|
||||
# "decoder_attention_heads": 16,
|
||||
# "decoder_layers": 12,
|
||||
# "encoder_attention_heads": 16,
|
||||
# "encoder_layers": 12,
|
||||
# "dropout": 0.1,
|
||||
# "max_position_embeddings": 4096,
|
||||
# "num_hidden_layers": 12,
|
||||
# "num_beams": 3
|
||||
# }
|
||||
self._florence_config_obj = Florence2Config(**config_dict)
|
||||
return self._florence_config_obj
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.functional as functional
|
||||
import torch.utils.checkpoint
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from einops import rearrange
|
||||
@@ -190,10 +190,7 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
|
||||
class MySequential(nn.Sequential):
|
||||
def forward(self, *inputs):
|
||||
for module in self._modules.values():
|
||||
if type(inputs) == tuple:
|
||||
inputs = module(*inputs)
|
||||
else:
|
||||
inputs = module(inputs)
|
||||
inputs = module(*inputs) if isinstance(inputs, tuple) else module(inputs)
|
||||
return inputs
|
||||
|
||||
|
||||
@@ -206,7 +203,7 @@ class PreNorm(nn.Module):
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
shortcut = x
|
||||
if self.norm != None:
|
||||
if self.norm is not None:
|
||||
x, size = self.fn(self.norm(x), *args, **kwargs)
|
||||
else:
|
||||
x, size = self.fn(x, *args, **kwargs)
|
||||
@@ -259,11 +256,11 @@ class DepthWiseConv2d(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x, size):
|
||||
B, N, C = x.shape
|
||||
H, W = size
|
||||
assert N == H * W
|
||||
batch_size, num_tokens, channels = x.shape
|
||||
height, width = size
|
||||
assert num_tokens == height * width
|
||||
|
||||
x = self.dw(x.transpose(1, 2).view(B, C, H, W))
|
||||
x = self.dw(x.transpose(1, 2).view(batch_size, channels, height, width))
|
||||
size = (x.size(-2), x.size(-1))
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x, size
|
||||
@@ -286,20 +283,20 @@ class ConvEmbed(nn.Module):
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
def forward(self, x, size):
|
||||
H, W = size
|
||||
height, width = size
|
||||
if len(x.size()) == 3:
|
||||
if self.norm and self.pre_norm:
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=height, w=width)
|
||||
|
||||
x = self.proj(x)
|
||||
|
||||
_, _, H, W = x.shape
|
||||
_, _, height, width = x.shape
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
if self.norm and not self.pre_norm:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, (H, W)
|
||||
return x, (height, width)
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
@@ -311,16 +308,20 @@ class ChannelAttention(nn.Module):
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x, size):
|
||||
B, N, C = x.shape
|
||||
batch_size, num_tokens, channels = x.shape
|
||||
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(batch_size, num_tokens, 3, self.groups, channels // self.groups)
|
||||
.permute(2, 0, 3, 1, 4)
|
||||
)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * (float(N) ** -0.5)
|
||||
q = q * (float(num_tokens) ** -0.5)
|
||||
attention = q.transpose(-1, -2) @ k
|
||||
attention = attention.softmax(dim=-1)
|
||||
x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
return x, size
|
||||
|
||||
@@ -366,18 +367,17 @@ class ChannelBlock(nn.Module):
|
||||
|
||||
|
||||
def window_partition(x, window_size: int):
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
batch_size, height, width, channels = x.shape
|
||||
x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
|
||||
B = batch_size
|
||||
def window_reverse(windows, batch_size: int, window_size: int, height: int, width: int):
|
||||
# this will cause onnx conversion failed for dynamic axis, because treated as constant
|
||||
# int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
# int(windows.shape[0] / (height * width / window_size / window_size))
|
||||
x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
|
||||
return x
|
||||
|
||||
|
||||
@@ -396,43 +396,47 @@ class WindowAttention(nn.Module):
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, size):
|
||||
H, W = size
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
height, width = size
|
||||
batch_size, seq_len, channels = x.shape
|
||||
assert seq_len == height * width, "input feature has wrong size"
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
x = x.view(batch_size, height, width, channels)
|
||||
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, Hp, Wp, _ = x.shape
|
||||
pad_r = (self.window_size - width % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - height % self.window_size) % self.window_size
|
||||
x = functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, height_padded, width_padded, _ = x.shape
|
||||
|
||||
x = window_partition(x, self.window_size)
|
||||
x = x.view(-1, self.window_size * self.window_size, C)
|
||||
x = x.view(-1, self.window_size * self.window_size, channels)
|
||||
|
||||
# W-MSA/SW-MSA
|
||||
# attn_windows = self.attn(x_windows)
|
||||
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
batch_windows, num_tokens, channels = x.shape
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(batch_windows, num_tokens, 3, self.num_heads, channels // self.num_heads)
|
||||
.permute(2, 0, 3, 1, 4)
|
||||
)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
attn = self.softmax(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = (attn @ v).transpose(1, 2).reshape(batch_windows, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
|
||||
# merge windows
|
||||
x = x.view(-1, self.window_size, self.window_size, C)
|
||||
x = window_reverse(x, B, self.window_size, Hp, Wp)
|
||||
x = x.view(-1, self.window_size, self.window_size, channels)
|
||||
x = window_reverse(x, batch_size, self.window_size, height_padded, width_padded)
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
x = x[:, :height, :width, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
x = x.view(batch_size, height * width, channels)
|
||||
|
||||
return x, size
|
||||
|
||||
@@ -606,7 +610,7 @@ class DaViT(nn.Module):
|
||||
x (_type_): input image tensor
|
||||
"""
|
||||
input_size = (x.size(2), x.size(3))
|
||||
for conv, block in zip(self.convs, self.blocks):
|
||||
for conv, block in zip(self.convs, self.blocks, strict=False):
|
||||
x, input_size = conv(x, input_size)
|
||||
if self.enable_checkpoint:
|
||||
x, input_size = checkpoint.checkpoint(block, x, input_size)
|
||||
@@ -656,7 +660,7 @@ def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
cu_seqlens = functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
@@ -1184,7 +1188,7 @@ class Florence2SdpaAttention(Florence2Attention):
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
|
||||
is_causal = bool(self.is_causal and attention_mask is None and tgt_len > 1)
|
||||
|
||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
@@ -1440,7 +1444,7 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel):
|
||||
for name, _ in module.named_parameters():
|
||||
if name == "bias":
|
||||
nn.init.constant_(module.bias, 0)
|
||||
elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.BatchNorm2d):
|
||||
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
nn.init.constant_(module.weight, 1.0)
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
@@ -1594,12 +1598,11 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
if head_mask.size()[0] != (len(self.layers)):
|
||||
raise ValueError(
|
||||
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
||||
f" {head_mask.size()[0]}."
|
||||
)
|
||||
if head_mask is not None and head_mask.size()[0] != (len(self.layers)):
|
||||
raise ValueError(
|
||||
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
||||
f" {head_mask.size()[0]}."
|
||||
)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
@@ -1845,12 +1848,11 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@@ -1860,14 +1862,13 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip(
|
||||
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
|
||||
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"], strict=False
|
||||
):
|
||||
if attn_mask is not None:
|
||||
if attn_mask.size()[0] != (len(self.layers)):
|
||||
raise ValueError(
|
||||
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
||||
f" {head_mask.size()[0]}."
|
||||
)
|
||||
if attn_mask is not None and attn_mask.size()[0] != (len(self.layers)):
|
||||
raise ValueError(
|
||||
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
||||
f" {head_mask.size()[0]}."
|
||||
)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
@@ -2494,37 +2495,39 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
||||
|
||||
def forward(self, pixel_values):
|
||||
if len(pixel_values.shape) == 4:
|
||||
batch_size, C, H, W = pixel_values.shape
|
||||
T = 1
|
||||
batch_size, channels, height, width = pixel_values.shape
|
||||
num_frames = 1
|
||||
x = self.vision_tower.forward_features_unpool(pixel_values)
|
||||
else:
|
||||
raise ValueError(f"invalid image shape {pixel_values.shape}")
|
||||
|
||||
if self.image_pos_embed is not None:
|
||||
x = x.view(batch_size * T, -1, x.shape[-1])
|
||||
x = x.view(batch_size * num_frames, -1, x.shape[-1])
|
||||
num_tokens = x.shape[-2]
|
||||
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
|
||||
assert h * w == num_tokens, "only support square feature maps for now"
|
||||
x = x.view(batch_size * T, h, w, x.shape[-1])
|
||||
x = x.view(batch_size * num_frames, h, w, x.shape[-1])
|
||||
pos_embed = self.image_pos_embed(x)
|
||||
x = x + pos_embed
|
||||
x = x.view(batch_size, T * h * w, x.shape[-1])
|
||||
x = x.view(batch_size, num_frames * h * w, x.shape[-1])
|
||||
|
||||
if self.visual_temporal_embed is not None:
|
||||
visual_temporal_embed = self.visual_temporal_embed(
|
||||
x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
|
||||
x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0]
|
||||
)
|
||||
x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view(
|
||||
1, num_frames, 1, x.shape[-1]
|
||||
)
|
||||
x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
|
||||
|
||||
x_feat_dict = {}
|
||||
|
||||
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
|
||||
spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2)
|
||||
x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
|
||||
|
||||
temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
|
||||
temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1)
|
||||
x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
|
||||
|
||||
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
|
||||
x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1]
|
||||
x_feat_dict["last_frame"] = x
|
||||
|
||||
new_x = []
|
||||
@@ -2619,37 +2622,39 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
||||
|
||||
def _encode_image(self, pixel_values):
|
||||
if len(pixel_values.shape) == 4:
|
||||
batch_size, C, H, W = pixel_values.shape
|
||||
T = 1
|
||||
batch_size, channels, height, width = pixel_values.shape
|
||||
num_frames = 1
|
||||
x = self.vision_tower.forward_features_unpool(pixel_values)
|
||||
else:
|
||||
raise ValueError(f"invalid image shape {pixel_values.shape}")
|
||||
|
||||
if self.image_pos_embed is not None:
|
||||
x = x.view(batch_size * T, -1, x.shape[-1])
|
||||
x = x.view(batch_size * num_frames, -1, x.shape[-1])
|
||||
num_tokens = x.shape[-2]
|
||||
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
|
||||
assert h * w == num_tokens, "only support square feature maps for now"
|
||||
x = x.view(batch_size * T, h, w, x.shape[-1])
|
||||
x = x.view(batch_size * num_frames, h, w, x.shape[-1])
|
||||
pos_embed = self.image_pos_embed(x)
|
||||
x = x + pos_embed
|
||||
x = x.view(batch_size, T * h * w, x.shape[-1])
|
||||
x = x.view(batch_size, num_frames * h * w, x.shape[-1])
|
||||
|
||||
if self.visual_temporal_embed is not None:
|
||||
visual_temporal_embed = self.visual_temporal_embed(
|
||||
x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
|
||||
x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0]
|
||||
)
|
||||
x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view(
|
||||
1, num_frames, 1, x.shape[-1]
|
||||
)
|
||||
x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
|
||||
|
||||
x_feat_dict = {}
|
||||
|
||||
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
|
||||
spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2)
|
||||
x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
|
||||
|
||||
temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
|
||||
temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1)
|
||||
x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
|
||||
|
||||
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
|
||||
x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1]
|
||||
x_feat_dict["last_frame"] = x
|
||||
|
||||
new_x = []
|
||||
|
||||
@@ -20,6 +20,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
@@ -31,7 +32,7 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
|
||||
from .action_hub import build_action_space
|
||||
from .configuration_florence2 import Florence2Config
|
||||
from .configuration_xvla import XVLAConfig
|
||||
from .configuration_xvla import XVLAConfig, XVLAConfig as PreTrainedConfig
|
||||
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||
from .transformer import SoftPromptedTransformer
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ class XVLAImageScaleProcessorStep(ProcessorStep):
|
||||
keys_to_scale = self.image_keys
|
||||
if keys_to_scale is None:
|
||||
# Auto-detect image keys
|
||||
keys_to_scale = [k for k in obs.keys() if k.startswith("observation.images.")]
|
||||
keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
|
||||
|
||||
# Scale each image
|
||||
for key in keys_to_scale:
|
||||
@@ -161,10 +161,7 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
||||
"""Add domain_id to complementary data."""
|
||||
new_transition = transition.copy()
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if comp is None:
|
||||
comp = {}
|
||||
else:
|
||||
comp = comp.copy()
|
||||
comp = {} if comp is None else comp.copy()
|
||||
|
||||
# Infer batch size from observation tensors
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Final
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.functional as functional
|
||||
|
||||
# ------------------------------- Small utils ----------------------------------
|
||||
|
||||
@@ -38,7 +38,7 @@ def _to_2tuple(x) -> tuple:
|
||||
|
||||
def _has_sdp_attention() -> bool:
|
||||
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
|
||||
return hasattr(F, "scaled_dot_product_attention")
|
||||
return hasattr(functional, "scaled_dot_product_attention")
|
||||
|
||||
|
||||
# ---------------------------------- MLP --------------------------------------
|
||||
@@ -127,38 +127,38 @@ class Attention(nn.Module):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor, shape [B, T, C]
|
||||
x : Tensor, shape [batch_size, seq_len, channels]
|
||||
Input sequence.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor, shape [B, T, C]
|
||||
Tensor, shape [batch_size, seq_len, channels]
|
||||
Output sequence after MHSA + projection.
|
||||
"""
|
||||
B, T, C = x.shape
|
||||
batch_size, seq_len, channels = x.shape
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(B, T, 3, self.num_heads, self.head_dim)
|
||||
.permute(2, 0, 3, 1, 4) # 3 x [B, H, T, Dh]
|
||||
.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
|
||||
.permute(2, 0, 3, 1, 4) # 3 x [batch_size, num_heads, seq_len, head_dim]
|
||||
)
|
||||
q, k, v = qkv.unbind(0) # each: [B, H, T, Dh]
|
||||
q, k, v = qkv.unbind(0) # each: [batch_size, num_heads, seq_len, head_dim]
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
x = functional.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
) # [B, H, T, Dh]
|
||||
) # [batch_size, num_heads, seq_len, head_dim]
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
|
||||
attn = q @ k.transpose(-2, -1) # [batch_size, num_heads, seq_len, seq_len]
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v # [B, H, T, Dh]
|
||||
x = attn @ v # [batch_size, num_heads, seq_len, head_dim]
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
|
||||
x = x.transpose(1, 2).reshape(batch_size, seq_len, channels) # [batch_size, seq_len, channels]
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
@@ -240,17 +240,17 @@ class DomainAwareLinear(nn.Module):
|
||||
Returns
|
||||
-------
|
||||
Tensor
|
||||
[B, O] or [B, T, O]
|
||||
[batch_size, output_size] or [batch_size, seq_len, output_size]
|
||||
"""
|
||||
B = domain_id.shape[0]
|
||||
squeeze_T = False
|
||||
batch_size = domain_id.shape[0]
|
||||
squeeze_seq = False
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
squeeze_T = True
|
||||
W = self.fc(domain_id).view(B, self.input_size, self.output_size)
|
||||
b = self.bias(domain_id).view(B, self.output_size)
|
||||
y = torch.matmul(x, W) + b.view(B, 1, self.output_size)
|
||||
if squeeze_T:
|
||||
squeeze_seq = True
|
||||
weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
|
||||
bias = self.bias(domain_id).view(batch_size, self.output_size)
|
||||
y = torch.matmul(x, weight) + bias.view(batch_size, 1, self.output_size)
|
||||
if squeeze_seq:
|
||||
y = y.squeeze(1)
|
||||
return y
|
||||
|
||||
@@ -370,16 +370,16 @@ class SoftPromptedTransformer(nn.Module):
|
||||
Returns
|
||||
-------
|
||||
Tensor
|
||||
Predicted actions, [B, T_action, dim_action]
|
||||
Predicted actions, [batch_size, num_actions, dim_action]
|
||||
"""
|
||||
B, num_actions = action_with_noise.shape[:2]
|
||||
batch_size, num_actions = action_with_noise.shape[:2]
|
||||
|
||||
# Encode (action + proprio + time) → tokens
|
||||
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
|
||||
time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
|
||||
proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
|
||||
time_emb = timestep_embedding(t, self.dim_time) # [batch_size, dim_time]
|
||||
time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time)
|
||||
proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1])
|
||||
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
|
||||
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
|
||||
x = self.action_encoder(action_tokens, domain_id) # [batch_size, num_actions, hidden_size]
|
||||
|
||||
# Project visual streams and concatenate
|
||||
if self.use_hetero_proj:
|
||||
@@ -402,7 +402,9 @@ class SoftPromptedTransformer(nn.Module):
|
||||
|
||||
# Append soft prompts
|
||||
if self.len_soft_prompts > 0:
|
||||
soft_prompts = self.soft_prompt_hub(domain_id).view(B, self.len_soft_prompts, self.hidden_size)
|
||||
soft_prompts = self.soft_prompt_hub(domain_id).view(
|
||||
batch_size, self.len_soft_prompts, self.hidden_size
|
||||
)
|
||||
x = torch.cat([x, soft_prompts], dim=1)
|
||||
|
||||
# Transformer backbone
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
import robosuite.utils.transform_utils as T
|
||||
import robosuite.utils.transform_utils as transform_utils
|
||||
|
||||
|
||||
def rotate6d_to_axis_angle(r6d):
|
||||
@@ -26,12 +26,12 @@ def rotate6d_to_axis_angle(r6d):
|
||||
# b3
|
||||
b3 = np.cross(b1, b2, axis=-1)
|
||||
|
||||
R = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
|
||||
rotation_matrix = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
|
||||
|
||||
axis_angle_list = []
|
||||
for i in range(R.shape[0]):
|
||||
quat = T.mat2quat(R[i])
|
||||
axis_angle = T.quat2axisangle(quat)
|
||||
for i in range(rotation_matrix.shape[0]):
|
||||
quat = transform_utils.mat2quat(rotation_matrix[i])
|
||||
axis_angle = transform_utils.quat2axisangle(quat)
|
||||
axis_angle_list.append(axis_angle)
|
||||
|
||||
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
|
||||
|
||||
@@ -51,8 +51,6 @@ from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
)
|
||||
|
||||
# login to hf
|
||||
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
|
||||
Reference in New Issue
Block a user