diff --git a/examples/tutorial/xvla/inference.py b/examples/tutorial/xvla/inference.py deleted file mode 100644 index 8f59adaeb..000000000 --- a/examples/tutorial/xvla/inference.py +++ /dev/null @@ -1,57 +0,0 @@ -import numpy as np -import torch -from PIL import Image -from transformers import AutoModel, AutoProcessor - -model = AutoModel.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True) - -processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True) - - -# append 3 random image to a list -def make_random_pil_images(num_images=3, H=480, W=640): - images = [] - for _ in range(num_images): - # Random RGB image - arr = np.random.randint(0, 256, (H, W, 3), dtype=np.uint8) - img = Image.fromarray(arr) - images.append(img) - return images - - -# Example: -images = make_random_pil_images() -language_instruction = "This is a random image" -# Multimodal preprocessing by processor -inputs = processor(images, language_instruction) -if not {"input_ids", "image_input", "image_mask"}.issubset(inputs): - raise ValueError("Processor did not return the expected keys.") - -proprio = torch.randn(1, 20) -domain_id = torch.tensor([0], dtype=torch.long) - -# Align to model's device/dtype -device = model.device -dtype = next(model.parameters()).dtype - - -def to_model(t: torch.Tensor) -> torch.Tensor: - if not isinstance(t, torch.Tensor): - t = torch.as_tensor(t) - # cast floats to model dtype, keep integral/bool as-is - return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device) - - -inputs = {k: to_model(v) for k, v in inputs.items()} -inputs.update( - { - "proprio": to_model(proprio), - "domain_id": domain_id.to(device), - } -) - -# Inference - -action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy() - -breakpoint() diff --git a/examples/tutorial/xvla/inference_pipe.py b/examples/tutorial/xvla/inference_pipe.py deleted file mode 100644 index 93bff30ca..000000000 --- a/examples/tutorial/xvla/inference_pipe.py +++ /dev/null @@ -1,63 +0,0 @@ -import numpy as np -import torch - -# from lerobot.policies.xvla.configuration_xvla import XVLAConfig -from lerobot.configs.policies import PreTrainedConfig -from lerobot.envs.factory import make_env_config -from lerobot.policies.factory import make_policy, make_pre_post_processors -from lerobot.policies.xvla.utils import rotate6d_to_axis_angle -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE - -observation_height: int = 360 -observation_width: int = 360 -# create an observation dict -OBS = { - f"{OBS_IMAGES}.image1": torch.randn(1, 3, observation_height, observation_width), - f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width), - OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string - "task": "put the object in the box", -} - - -def fake_rgb(H, W): - img = torch.randint(0, 255, (H, W, 3), dtype=torch.uint8).numpy() - return img - - -OBS[f"{OBS_IMAGES}.image1"] = fake_rgb(observation_height, observation_width) -OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width) - -# observation = preprocessor(OBS) -from transformers import AutoProcessor - -processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True) -inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"]) -breakpoint() - -cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated") -cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated" 
-env_cfg = make_env_config("libero", task="libero_spatial") -policy = make_policy( - cfg=cfg, - env_cfg=env_cfg, -) - -policy.eval() - -preprocessor_overrides = { - "device_processor": {"device": str(cfg.device)}, -} - -preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg, - pretrained_path=cfg.pretrained_path, - preprocessor_overrides=preprocessor_overrides, -) - -observation = preprocessor(OBS) -action = policy.select_action(observation) - -target_eef = action[:, :3].to("cpu").numpy() -target_axis = rotate6d_to_axis_angle(action[:, 3:9].to("cpu").numpy()) -target_act = action[:, 9:10].to("cpu").numpy() -final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1) diff --git a/examples/tutorial/xvla/test_xvla.py b/examples/tutorial/xvla/test_xvla.py deleted file mode 100644 index d001b9877..000000000 --- a/examples/tutorial/xvla/test_xvla.py +++ /dev/null @@ -1,237 +0,0 @@ -import os - -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata -from lerobot.policies.factory import make_policy, make_policy_config - -cfg = make_policy_config("xvla") - -dataset_id = "lerobot/svla_so101_pickplace" -# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets -dataset_metadata = LeRobotDatasetMetadata(dataset_id) -policy = make_policy(cfg=cfg, ds_meta=dataset_metadata) - -for name, param in policy.state_dict().items(): - print(name, param.shape) - - -# now let's load in safetensors -import safetensors.torch -from huggingface_hub import snapshot_download - -cache_dir = snapshot_download( - repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model" -) -state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors")) -# policy.load_state_dict(state_dict) -# 3. Add "model." prefix to every key -new_state_dict = {f"model.{k}": v for k, v in state_dict.items()} -keys_to_skip = [ - "model.transformer.action_encoder.fc.weight", - "model.transformer.action_encoder.fc.bias", -] - -new_state_dict = {k: v for k, v in new_state_dict.items() if k not in keys_to_skip} -# 4. 
Load into your model -missing, unexpected = policy.load_state_dict(new_state_dict, strict=False) - -print("missing keys:", missing) - -print() -print("unexpected keys:", unexpected) - - -import random - -import numpy as np -import torch -from xvla.models.modeling_xvla import XVLA - -# from lerobot.policies.xvla.configuration_xvla import XVLAConfig -from lerobot.configs.policies import PreTrainedConfig -from lerobot.envs.factory import make_env_config -from lerobot.policies.factory import make_policy, make_pre_post_processors -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE - -torch.manual_seed(42) -random.seed(42) -np.random.seed(42) -observation_height: int = 224 -observation_width: int = 224 # todo: jadechoghari, image size is different for the two models -# create an observation dict -OBS = { - f"{OBS_IMAGES}.image": torch.randn(1, 3, observation_height, observation_width), - f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width), - OBS_STATE: torch.randn(1, 20), # ONLY if OBS_STATE is already a string - "task": "put the object in the box", -} - -IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) -IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) - - -def fake_rgb(H, W): - arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8) - t = torch.from_numpy(arr).permute(2, 0, 1) # CHW - t = t.unsqueeze(0).float() - # normalize pixel to imagenet - return t - - -OBS[f"{OBS_IMAGES}.image"] = fake_rgb(observation_height, observation_width) -OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width) - -cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated") -cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated" -env_cfg = make_env_config("libero", task="libero_spatial") -policy = make_policy( - cfg=cfg, - env_cfg=env_cfg, -) - -policy.eval() - -preprocessor_overrides = { - "device_processor": {"device": str(cfg.device)}, -} - -preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg, - pretrained_path=cfg.pretrained_path, - preprocessor_overrides=preprocessor_overrides, -) - -observation = preprocessor(OBS) -inputs = policy._build_model_inputs(observation) - - -#### now the og model ########################################################### -from xvla.models.processing_xvla import XVLAProcessor - -processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2) -inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"]) -domain_id = torch.tensor([3], dtype=torch.long) -inputs.update( - { - "proprio": OBS[OBS_STATE].to("cuda"), - "domain_id": domain_id.to("cuda"), - } -) - - -for k in inputs.keys() & inputs_1.keys(): # intersection of keys - a = inputs[k] - b = inputs_1[k].to("cuda") - - print(f"\nšŸ”Ž Key: {k}") - - # Check shape - print(" shape:", a.shape, b.shape) - - # Check if close - if torch.allclose(a, b, atol=1e-5, rtol=1e-5): - print(" āœ”ļø tensors are equal (allclose)") - else: - diff = torch.abs(a - b) - print(" āŒ tensors differ") - print(" max diff:", diff.max().item()) - print(" mean diff:", diff.mean().item()) - - -model = XVLA.from_pretrained("/raid/jade/models/xvla-libero") -model.eval() -model.to("cuda") - -action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy() -# (Pdb) inputs['input_ids'].shape -# torch.Size([1, 64]) -# (Pdb) inputs_1['input_ids'].shape -# torch.Size([1, 50]) -# (Pdb) [0, 0, :, :4, 0] -action_1 = 
policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy() - -# np all close -print(np.allclose(action, action_1, atol=1e-2, rtol=1e-2)) -print("max diff:", np.max(np.abs(action - action_1))) -print("mean diff:", np.mean(np.abs(action - action_1))) - - -import random - -import numpy as np -import torch -from PIL import Image -from xvla.models.configuration_xvla import XVLAConfig -from xvla.models.modeling_xvla import XVLA -from xvla.models.processor_xvla import XVLAProcessor - -from lerobot.configs.policies import PreTrainedConfig -from lerobot.envs.factory import make_env_config -from lerobot.policies.factory import make_policy - -cfg = XVLAConfig.from_pretrained("/raid/jade/models/xvla-libero") -model = XVLA.from_pretrained("/raid/jade/models/xvla-libero") -model.eval() -model.to("cuda") -processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero") -# /raid/jade/models/xvla-libero -# seet seed -torch.manual_seed(42) -random.seed(42) -np.random.seed(42) - - -def make_random_pil_images(num_images=3, H=480, W=640): - images = [] - for _ in range(num_images): - # Random RGB image - arr = np.random.randint(0, 256, (H, W, 3), dtype=np.uint8) - img = Image.fromarray(arr) - images.append(img) - return images - - -# Example: -images = make_random_pil_images() -language_instruction = "This is a random image" -# Multimodal preprocessing by processor -inputs = processor(images, language_instruction) -if not {"input_ids", "image_input", "image_mask"}.issubset(inputs): - raise ValueError("Processor did not return the expected keys.") - -proprio = torch.randn(1, 20) -domain_id = torch.tensor([0], dtype=torch.long) - -# Align to model's device/dtype -device = model.device -dtype = next(model.parameters()).dtype - - -def to_model(t: torch.Tensor) -> torch.Tensor: - if not isinstance(t, torch.Tensor): - t = torch.as_tensor(t) - # cast floats to model dtype, keep integral/bool as-is - return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device) - - -inputs = {k: to_model(v) for k, v in inputs.items()} -inputs.update( - { - "proprio": to_model(proprio), - "domain_id": domain_id.to(device), - } -) - -# Inference -action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy() - - -#### now for lerobot model ##################################################### - -cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated") -env_cfg = make_env_config("libero", task="libero_spatial") -cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated" -policy = make_policy(cfg=cfg, env_cfg=env_cfg) -policy.eval() -policy.to("cuda") - -action_1 = policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy() diff --git a/src/lerobot/policies/xvla/action_hub.py b/src/lerobot/policies/xvla/action_hub.py index 8e71cd3cf..e7e6485cd 100644 --- a/src/lerobot/policies/xvla/action_hub.py +++ b/src/lerobot/policies/xvla/action_hub.py @@ -101,10 +101,10 @@ class BaseActionSpace(nn.Module): # ============================================================================= # Utilities # ============================================================================= -def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None: - bad = [i for i in idx if i < 0 or i >= D] +def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None: + bad = [i for i in idx if i < 0 or i >= dim_action] if bad: - raise IndexError(f"{name} contains out-of-range indices {bad} for action dim 
D={D}") + raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}") # ============================================================================= @@ -132,8 +132,8 @@ class EE6DActionSpace(BaseActionSpace): def compute_loss(self, pred, target): assert pred.shape == target.shape, "pred/target shapes must match" - B, T, D = pred.shape - _ensure_indices_valid(D, self.gripper_idx, "gripper_idx") + batch_size, seq_len, action_dim = pred.shape + _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx") # Gripper BCE g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx] @@ -188,13 +188,13 @@ class JointActionSpace(BaseActionSpace): def compute_loss(self, pred, target): assert pred.shape == target.shape - B, T, D = pred.shape - _ensure_indices_valid(D, self.gripper_idx, "gripper_idx") + batch_size, seq_len, action_dim = pred.shape + _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx") g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx] gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE - joints_idx = tuple(i for i in range(D) if i not in set(self.gripper_idx)) + joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx)) joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE return { @@ -237,8 +237,8 @@ class AGIBOTEE6DActionSpace(BaseActionSpace): def compute_loss(self, pred, target): assert pred.shape == target.shape - B, T, D = pred.shape - _ensure_indices_valid(D, self.gripper_idx, "gripper_idx") + batch_size, seq_len, action_dim = pred.shape + _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx") gripper_loss = ( self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE @@ -280,7 +280,7 @@ class FrankaJoint7ActionSpace(BaseActionSpace): def compute_loss(self, pred, target): assert pred.shape == target.shape, "pred/target shapes must match" - B, T, D = pred.shape + batch_size, seq_len, action_dim = pred.shape joints_loss = self.mse(pred, target) * self.JOINTS_SCALE return {"joints_loss": joints_loss} diff --git a/src/lerobot/policies/xvla/configuration_florence2.py b/src/lerobot/policies/xvla/configuration_florence2.py index 1a66cf3e2..20b32976b 100644 --- a/src/lerobot/policies/xvla/configuration_florence2.py +++ b/src/lerobot/policies/xvla/configuration_florence2.py @@ -12,12 +12,11 @@ # limitations under the License. 
import warnings -""" Florence-2 configuration""" - - from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging +""" Florence-2 configuration""" + logger = logging.get_logger(__name__) @@ -82,37 +81,54 @@ class Florence2VisionConfig(PretrainedConfig): def __init__( self, drop_path_rate=0.1, - patch_size=[7, 3, 3, 3], - patch_stride=[4, 2, 2, 2], - patch_padding=[3, 1, 1, 1], - patch_prenorm=[False, True, True, True], + patch_size=None, + patch_stride=None, + patch_padding=None, + patch_prenorm=None, enable_checkpoint=False, - dim_embed=[256, 512, 1024, 2048], - num_heads=[8, 16, 32, 64], - num_groups=[8, 16, 32, 64], - depths=[1, 1, 9, 1], + dim_embed=None, + num_heads=None, + num_groups=None, + depths=None, window_size=12, projection_dim=1024, visual_temporal_embedding=None, image_pos_embed=None, - image_feature_source=["spatial_avg_pool", "temporal_avg_pool"], + image_feature_source=None, **kwargs, ): self.drop_path_rate = drop_path_rate - self.patch_size = patch_size - self.patch_stride = patch_stride - self.patch_padding = patch_padding - self.patch_prenorm = patch_prenorm + self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3] + self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2] + self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1] + self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True] self.enable_checkpoint = enable_checkpoint - self.dim_embed = dim_embed - self.num_heads = num_heads - self.num_groups = num_groups - self.depths = depths + self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048] + self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64] + self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64] + self.depths = depths if depths is not None else [1, 1, 9, 1] self.window_size = window_size self.projection_dim = projection_dim + + if visual_temporal_embedding is None: + visual_temporal_embedding = { + "type": "COSINE", + "max_temporal_embeddings": 100, + } self.visual_temporal_embedding = visual_temporal_embedding + + if image_pos_embed is None: + image_pos_embed = { + "type": "learned_abs_2d", + "max_pos_embeddings": 1000, + } self.image_pos_embed = image_pos_embed - self.image_feature_source = image_feature_source + + self.image_feature_source = ( + image_feature_source + if image_feature_source is not None + else ["spatial_avg_pool", "temporal_avg_pool"] + ) super().__init__(**kwargs) @@ -264,7 +280,8 @@ class Florence2LanguageConfig(PretrainedConfig): self.forced_bos_token_id = self.bos_token_id warnings.warn( f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. " - "The config can simply be saved and uploaded again to be fixed." + "The config can simply be saved and uploaded again to be fixed.", + stacklevel=2, ) diff --git a/src/lerobot/policies/xvla/configuration_xvla.py b/src/lerobot/policies/xvla/configuration_xvla.py index c6c3f1f25..4bfa8704a 100644 --- a/src/lerobot/policies/xvla/configuration_xvla.py +++ b/src/lerobot/policies/xvla/configuration_xvla.py @@ -116,54 +116,12 @@ class XVLAConfig(PreTrainedConfig): Build (and cache) the Florence2 transformer config that should back the VLM. 
""" if self._florence_config_obj is None: - # TODO: jadechoghari: provide default way, and do not hardcode - # Ensure vision_config and text_config are provided with defaults if not specified config_dict = dict(self.florence_config) if "vision_config" not in config_dict or config_dict["vision_config"] is None: raise ValueError("vision_config is required") - # # Provide default vision config - # config_dict["vision_config"] = { - # "model_type": "davit", - # "drop_path_rate": 0.1, - # "patch_size": [7, 3, 3, 3], - # "patch_stride": [4, 2, 2, 2], - # "patch_padding": [3, 1, 1, 1], - # "patch_prenorm": [False, True, True, True], - # "enable_checkpoint": False, - # "dim_embed": [256, 512, 1024, 2048], - # "num_heads": [8, 16, 32, 64], - # "num_groups": [8, 16, 32, 64], - # "depths": [1, 1, 9, 1], - # "window_size": 12, - # "projection_dim": 1024, - # "visual_temporal_embedding": { - # "type": "COSINE", - # "max_temporal_embeddings": 100 - # }, - # "image_pos_embed": { - # "type": "learned_abs_2d", - # "max_pos_embeddings": 50 - # }, - # "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"] - # } + if "text_config" not in config_dict or config_dict["text_config"] is None: raise ValueError("text_config is required") - # # Provide default text config - # config_dict["text_config"] = { - # "vocab_size": 51289, - # "activation_dropout": 0.1, - # "activation_function": "gelu", - # "attention_dropout": 0.1, - # "d_model": 1024, - # "decoder_attention_heads": 16, - # "decoder_layers": 12, - # "encoder_attention_heads": 16, - # "encoder_layers": 12, - # "dropout": 0.1, - # "max_position_embeddings": 4096, - # "num_hidden_layers": 12, - # "num_beams": 3 - # } self._florence_config_obj = Florence2Config(**config_dict) return self._florence_config_obj diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py index e61689568..f783d0818 100644 --- a/src/lerobot/policies/xvla/modeling_florence2.py +++ b/src/lerobot/policies/xvla/modeling_florence2.py @@ -19,7 +19,7 @@ from collections import OrderedDict from dataclasses import dataclass import torch -import torch.nn.functional as F +import torch.nn.functional as functional import torch.utils.checkpoint import torch.utils.checkpoint as checkpoint from einops import rearrange @@ -190,10 +190,7 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module): class MySequential(nn.Sequential): def forward(self, *inputs): for module in self._modules.values(): - if type(inputs) == tuple: - inputs = module(*inputs) - else: - inputs = module(inputs) + inputs = module(*inputs) if isinstance(inputs, tuple) else module(inputs) return inputs @@ -206,7 +203,7 @@ class PreNorm(nn.Module): def forward(self, x, *args, **kwargs): shortcut = x - if self.norm != None: + if self.norm is not None: x, size = self.fn(self.norm(x), *args, **kwargs) else: x, size = self.fn(x, *args, **kwargs) @@ -259,11 +256,11 @@ class DepthWiseConv2d(nn.Module): ) def forward(self, x, size): - B, N, C = x.shape - H, W = size - assert N == H * W + batch_size, num_tokens, channels = x.shape + height, width = size + assert num_tokens == height * width - x = self.dw(x.transpose(1, 2).view(B, C, H, W)) + x = self.dw(x.transpose(1, 2).view(batch_size, channels, height, width)) size = (x.size(-2), x.size(-1)) x = x.flatten(2).transpose(1, 2) return x, size @@ -286,20 +283,20 @@ class ConvEmbed(nn.Module): self.pre_norm = pre_norm def forward(self, x, size): - H, W = size + height, width = size if len(x.size()) == 3: if self.norm and self.pre_norm: x 
= self.norm(x) - x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) + x = rearrange(x, "b (h w) c -> b c h w", h=height, w=width) x = self.proj(x) - _, _, H, W = x.shape + _, _, height, width = x.shape x = rearrange(x, "b c h w -> b (h w) c") if self.norm and not self.pre_norm: x = self.norm(x) - return x, (H, W) + return x, (height, width) class ChannelAttention(nn.Module): @@ -311,16 +308,20 @@ class ChannelAttention(nn.Module): self.proj = nn.Linear(dim, dim) def forward(self, x, size): - B, N, C = x.shape + batch_size, num_tokens, channels = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x) + .reshape(batch_size, num_tokens, 3, self.groups, channels // self.groups) + .permute(2, 0, 3, 1, 4) + ) q, k, v = qkv[0], qkv[1], qkv[2] - q = q * (float(N) ** -0.5) + q = q * (float(num_tokens) ** -0.5) attention = q.transpose(-1, -2) @ k attention = attention.softmax(dim=-1) x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) - x = x.transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels) x = self.proj(x) return x, size @@ -366,18 +367,17 @@ class ChannelBlock(nn.Module): def window_partition(x, window_size: int): - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + batch_size, height, width, channels = x.shape + x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels) return windows -def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): - B = batch_size +def window_reverse(windows, batch_size: int, window_size: int, height: int, width: int): # this will cause onnx conversion failed for dynamic axis, because treated as constant - # int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + # int(windows.shape[0] / (height * width / window_size / window_size)) + x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) return x @@ -396,43 +396,47 @@ class WindowAttention(nn.Module): self.softmax = nn.Softmax(dim=-1) def forward(self, x, size): - H, W = size - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" + height, width = size + batch_size, seq_len, channels = x.shape + assert seq_len == height * width, "input feature has wrong size" - x = x.view(B, H, W, C) + x = x.view(batch_size, height, width, channels) pad_l = pad_t = 0 - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape + pad_r = (self.window_size - width % self.window_size) % self.window_size + pad_b = (self.window_size - height % self.window_size) % self.window_size + x = functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, height_padded, width_padded, _ = x.shape x = window_partition(x, self.window_size) - x = x.view(-1, self.window_size * self.window_size, C) + x = x.view(-1, self.window_size * self.window_size, channels) # W-MSA/SW-MSA # 
attn_windows = self.attn(x_windows) - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + batch_windows, num_tokens, channels = x.shape + qkv = ( + self.qkv(x) + .reshape(batch_windows, num_tokens, 3, self.num_heads, channels // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale attn = q @ k.transpose(-2, -1) attn = self.softmax(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = (attn @ v).transpose(1, 2).reshape(batch_windows, num_tokens, channels) x = self.proj(x) # merge windows - x = x.view(-1, self.window_size, self.window_size, C) - x = window_reverse(x, B, self.window_size, Hp, Wp) + x = x.view(-1, self.window_size, self.window_size, channels) + x = window_reverse(x, batch_size, self.window_size, height_padded, width_padded) if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() + x = x[:, :height, :width, :].contiguous() - x = x.view(B, H * W, C) + x = x.view(batch_size, height * width, channels) return x, size @@ -606,7 +610,7 @@ class DaViT(nn.Module): x (_type_): input image tensor """ input_size = (x.size(2), x.size(3)) - for conv, block in zip(self.convs, self.blocks): + for conv, block in zip(self.convs, self.blocks, strict=False): x, input_size = conv(x, input_size) if self.enable_checkpoint: x, input_size = checkpoint.checkpoint(block, x, input_size) @@ -656,7 +660,7 @@ def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + cu_seqlens = functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, @@ -1184,7 +1188,7 @@ class Florence2SdpaAttention(Florence2Attention): # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + is_causal = bool(self.is_causal and attention_mask is None and tgt_len > 1) # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577
@@ -1440,7 +1444,7 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel):
             for name, _ in module.named_parameters():
                 if name == "bias":
                     nn.init.constant_(module.bias, 0)
-        elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.BatchNorm2d):
+        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
             nn.init.constant_(module.weight, 1.0)
             nn.init.constant_(module.bias, 0)
 
@@ -1594,12 +1598,11 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
         all_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
-        if head_mask is not None:
-            if head_mask.size()[0] != (len(self.layers)):
-                raise ValueError(
-                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
-                    f" {head_mask.size()[0]}."
-                )
+        if head_mask is not None and head_mask.size()[0] != (len(self.layers)):
+            raise ValueError(
+                f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                f" {head_mask.size()[0]}."
+            )
 
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
@@ -1845,12 +1848,11 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
 
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -1860,14 +1862,13 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
 
         # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
         for attn_mask, mask_name in zip(
-            [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
+            [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"], strict=False
         ):
-            if attn_mask is not None:
-                if attn_mask.size()[0] != (len(self.layers)):
-                    raise ValueError(
-                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
-                        f" {head_mask.size()[0]}."
-                    )
+            if attn_mask is not None and attn_mask.size()[0] != (len(self.layers)):
+                raise ValueError(
+                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                    f" {attn_mask.size()[0]}."
+ ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -2494,37 +2495,39 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel): def forward(self, pixel_values): if len(pixel_values.shape) == 4: - batch_size, C, H, W = pixel_values.shape - T = 1 + batch_size, channels, height, width = pixel_values.shape + num_frames = 1 x = self.vision_tower.forward_features_unpool(pixel_values) else: raise ValueError(f"invalid image shape {pixel_values.shape}") if self.image_pos_embed is not None: - x = x.view(batch_size * T, -1, x.shape[-1]) + x = x.view(batch_size * num_frames, -1, x.shape[-1]) num_tokens = x.shape[-2] h, w = int(num_tokens**0.5), int(num_tokens**0.5) assert h * w == num_tokens, "only support square feature maps for now" - x = x.view(batch_size * T, h, w, x.shape[-1]) + x = x.view(batch_size * num_frames, h, w, x.shape[-1]) pos_embed = self.image_pos_embed(x) x = x + pos_embed - x = x.view(batch_size, T * h * w, x.shape[-1]) + x = x.view(batch_size, num_frames * h * w, x.shape[-1]) if self.visual_temporal_embed is not None: visual_temporal_embed = self.visual_temporal_embed( - x.view(batch_size, T, -1, x.shape[-1])[:, :, 0] + x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0] + ) + x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view( + 1, num_frames, 1, x.shape[-1] ) - x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) x_feat_dict = {} - spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2) x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x - temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) + temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1) x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x - x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1] x_feat_dict["last_frame"] = x new_x = [] @@ -2619,37 +2622,39 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): def _encode_image(self, pixel_values): if len(pixel_values.shape) == 4: - batch_size, C, H, W = pixel_values.shape - T = 1 + batch_size, channels, height, width = pixel_values.shape + num_frames = 1 x = self.vision_tower.forward_features_unpool(pixel_values) else: raise ValueError(f"invalid image shape {pixel_values.shape}") if self.image_pos_embed is not None: - x = x.view(batch_size * T, -1, x.shape[-1]) + x = x.view(batch_size * num_frames, -1, x.shape[-1]) num_tokens = x.shape[-2] h, w = int(num_tokens**0.5), int(num_tokens**0.5) assert h * w == num_tokens, "only support square feature maps for now" - x = x.view(batch_size * T, h, w, x.shape[-1]) + x = x.view(batch_size * num_frames, h, w, x.shape[-1]) pos_embed = self.image_pos_embed(x) x = x + pos_embed - x = x.view(batch_size, T * h * w, x.shape[-1]) + x = x.view(batch_size, num_frames * h * w, x.shape[-1]) if self.visual_temporal_embed is not None: visual_temporal_embed = self.visual_temporal_embed( - x.view(batch_size, T, -1, x.shape[-1])[:, :, 0] + x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0] + ) + x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view( + 1, num_frames, 1, x.shape[-1] ) - x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) x_feat_dict = {} - spatial_avg_pool_x = x.view(batch_size, 
T, -1, x.shape[-1]).mean(dim=2) + spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2) x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x - temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) + temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1) x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x - x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1] x_feat_dict["last_frame"] = x new_x = [] diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index a8ee2eae0..549eb301e 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -20,6 +20,7 @@ from __future__ import annotations import os from collections import deque +from pathlib import Path import torch import torch.nn.functional as F # noqa: N812 @@ -31,7 +32,7 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE from .action_hub import build_action_space from .configuration_florence2 import Florence2Config -from .configuration_xvla import XVLAConfig +from .configuration_xvla import XVLAConfig, XVLAConfig as PreTrainedConfig from .modeling_florence2 import Florence2ForConditionalGeneration from .transformer import SoftPromptedTransformer diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index 5296f2131..d8a6e7092 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -120,7 +120,7 @@ class XVLAImageScaleProcessorStep(ProcessorStep): keys_to_scale = self.image_keys if keys_to_scale is None: # Auto-detect image keys - keys_to_scale = [k for k in obs.keys() if k.startswith("observation.images.")] + keys_to_scale = [k for k in obs if k.startswith("observation.images.")] # Scale each image for key in keys_to_scale: @@ -161,10 +161,7 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep): """Add domain_id to complementary data.""" new_transition = transition.copy() comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - if comp is None: - comp = {} - else: - comp = comp.copy() + comp = {} if comp is None else comp.copy() # Infer batch size from observation tensors obs = new_transition.get(TransitionKey.OBSERVATION, {}) diff --git a/src/lerobot/policies/xvla/transformer.py b/src/lerobot/policies/xvla/transformer.py index 3e43b446e..77ceb6e26 100644 --- a/src/lerobot/policies/xvla/transformer.py +++ b/src/lerobot/policies/xvla/transformer.py @@ -23,7 +23,7 @@ from typing import Final import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as functional # ------------------------------- Small utils ---------------------------------- @@ -38,7 +38,7 @@ def _to_2tuple(x) -> tuple: def _has_sdp_attention() -> bool: """Check if we can use PyTorch fused scaled_dot_product_attention.""" - return hasattr(F, "scaled_dot_product_attention") + return hasattr(functional, "scaled_dot_product_attention") # ---------------------------------- MLP -------------------------------------- @@ -127,38 +127,38 @@ class Attention(nn.Module): """ Parameters ---------- - x : Tensor, shape [B, T, C] + x : Tensor, shape [batch_size, seq_len, channels] Input sequence. Returns ------- - Tensor, shape [B, T, C] + Tensor, shape [batch_size, seq_len, channels] Output sequence after MHSA + projection. 
""" - B, T, C = x.shape + batch_size, seq_len, channels = x.shape qkv = ( self.qkv(x) - .reshape(B, T, 3, self.num_heads, self.head_dim) - .permute(2, 0, 3, 1, 4) # 3 x [B, H, T, Dh] + .reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) # 3 x [batch_size, num_heads, seq_len, head_dim] ) - q, k, v = qkv.unbind(0) # each: [B, H, T, Dh] + q, k, v = qkv.unbind(0) # each: [batch_size, num_heads, seq_len, head_dim] q, k = self.q_norm(q), self.k_norm(k) if self.fused_attn: - x = F.scaled_dot_product_attention( + x = functional.scaled_dot_product_attention( q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, - ) # [B, H, T, Dh] + ) # [batch_size, num_heads, seq_len, head_dim] else: q = q * self.scale - attn = q @ k.transpose(-2, -1) # [B, H, T, T] + attn = q @ k.transpose(-2, -1) # [batch_size, num_heads, seq_len, seq_len] attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = attn @ v # [B, H, T, Dh] + x = attn @ v # [batch_size, num_heads, seq_len, head_dim] - x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C] + x = x.transpose(1, 2).reshape(batch_size, seq_len, channels) # [batch_size, seq_len, channels] x = self.proj(x) x = self.proj_drop(x) return x @@ -240,17 +240,17 @@ class DomainAwareLinear(nn.Module): Returns ------- Tensor - [B, O] or [B, T, O] + [batch_size, output_size] or [batch_size, seq_len, output_size] """ - B = domain_id.shape[0] - squeeze_T = False + batch_size = domain_id.shape[0] + squeeze_seq = False if x.dim() == 2: x = x.unsqueeze(1) - squeeze_T = True - W = self.fc(domain_id).view(B, self.input_size, self.output_size) - b = self.bias(domain_id).view(B, self.output_size) - y = torch.matmul(x, W) + b.view(B, 1, self.output_size) - if squeeze_T: + squeeze_seq = True + weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size) + bias = self.bias(domain_id).view(batch_size, self.output_size) + y = torch.matmul(x, weight) + bias.view(batch_size, 1, self.output_size) + if squeeze_seq: y = y.squeeze(1) return y @@ -370,16 +370,16 @@ class SoftPromptedTransformer(nn.Module): Returns ------- Tensor - Predicted actions, [B, T_action, dim_action] + Predicted actions, [batch_size, num_actions, dim_action] """ - B, num_actions = action_with_noise.shape[:2] + batch_size, num_actions = action_with_noise.shape[:2] # Encode (action + proprio + time) → tokens - time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time] - time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time) - proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1]) + time_emb = timestep_embedding(t, self.dim_time) # [batch_size, dim_time] + time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time) + proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1]) action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1) - x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H] + x = self.action_encoder(action_tokens, domain_id) # [batch_size, num_actions, hidden_size] # Project visual streams and concatenate if self.use_hetero_proj: @@ -402,7 +402,9 @@ class SoftPromptedTransformer(nn.Module): # Append soft prompts if self.len_soft_prompts > 0: - soft_prompts = self.soft_prompt_hub(domain_id).view(B, self.len_soft_prompts, self.hidden_size) + soft_prompts = self.soft_prompt_hub(domain_id).view( + batch_size, self.len_soft_prompts, self.hidden_size + ) x = torch.cat([x, soft_prompts], dim=1) # Transformer backbone 
diff --git a/src/lerobot/policies/xvla/utils.py b/src/lerobot/policies/xvla/utils.py index abd02edcc..38e3e1f20 100644 --- a/src/lerobot/policies/xvla/utils.py +++ b/src/lerobot/policies/xvla/utils.py @@ -1,5 +1,5 @@ import numpy as np -import robosuite.utils.transform_utils as T +import robosuite.utils.transform_utils as transform_utils def rotate6d_to_axis_angle(r6d): @@ -26,12 +26,12 @@ def rotate6d_to_axis_angle(r6d): # b3 b3 = np.cross(b1, b2, axis=-1) - R = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3) + rotation_matrix = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3) axis_angle_list = [] - for i in range(R.shape[0]): - quat = T.mat2quat(R[i]) - axis_angle = T.quat2axisangle(quat) + for i in range(rotation_matrix.shape[0]): + quat = transform_utils.mat2quat(rotation_matrix[i]) + axis_angle = transform_utils.quat2axisangle(quat) axis_angle_list.append(axis_angle) axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 607fc7bc9..0cc6e037f 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -51,8 +51,6 @@ from lerobot.utils.utils import ( init_logging, ) -# login to hf - def update_policy( train_metrics: MetricsTracker,
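
For the rotate6d_to_axis_angle helper touched in utils.py above, a quick sanity check (a sketch, assuming robosuite is installed and the helper is importable from lerobot.policies.xvla.utils): the 6D input is the first two columns of the rotation matrix, per the np.stack([b1, b2, b3], axis=-1) construction, so the identity rotation should map to a zero axis-angle vector.

import numpy as np

from lerobot.policies.xvla.utils import rotate6d_to_axis_angle

# Identity rotation: columns e1 and e2 of the 3x3 identity, flattened to (1, 6).
r6d_identity = np.array([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
axis_angle = rotate6d_to_axis_angle(r6d_identity)  # expected shape: (1, 3)
print(np.allclose(axis_angle, 0.0))  # expected: True (zero rotation)
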