From 4ad41f7a766eec5d7c9ac05e4fcc6b1ee6bb8fda Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 28 Nov 2025 10:16:11 +0100 Subject: [PATCH] iterate on review --- pyproject.toml | 2 +- src/lerobot/policies/xvla/action_hub.py | 1 - .../policies/xvla/modeling_florence2.py | 118 +----------------- src/lerobot/policies/xvla/modeling_xvla.py | 9 +- src/lerobot/policies/xvla/processor_xvla.py | 7 +- 5 files changed, 10 insertions(+), 127 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c71bc45fb..14d5d9053 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,7 @@ groot = [ "ninja>=1.11.1,<2.0.0", "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" ] -xlva = ["lerobot[transformers-dep]"] +xvla = ["lerobot[transformers-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features diff --git a/src/lerobot/policies/xvla/action_hub.py b/src/lerobot/policies/xvla/action_hub.py index 80be6847d..15c0813ed 100644 --- a/src/lerobot/policies/xvla/action_hub.py +++ b/src/lerobot/policies/xvla/action_hub.py @@ -378,7 +378,6 @@ class BimanualSO101ActionSpace(BaseActionSpace): left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6]) right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12]) - # is gripper continuous? not bce? gripper_loss = ( self.mse( pred[:, :, [5, 11]], diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py index 49a5e9c84..280b13abb 100644 --- a/src/lerobot/policies/xvla/modeling_florence2.py +++ b/src/lerobot/policies/xvla/modeling_florence2.py @@ -50,10 +50,11 @@ from transformers.utils import ( replace_return_docstrings, ) -from .configuration_florence2 import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig +from .configuration_florence2 import Florence2Config, Florence2LanguageConfig from .utils import drop_path if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa logger = logging.get_logger(__name__) @@ -665,11 +666,6 @@ class DaViT(nn.Module): ) -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - # Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) @@ -2435,116 +2431,6 @@ FLORENCE2_INPUTS_DOCSTRING = r""" """ -@add_start_docstrings( - """The FLORENCE2 vision model without any head""", - FLORENCE2_START_DOCSTRING, -) -class Florence2VisionModel(Florence2PreTrainedModel): - def __init__(self, config: Florence2VisionConfig): - super().__init__(config) - assert config.model_type == "davit", "only DaViT is supported for now" - self.vision_tower = DaViT.from_config(config=config) - - self.post_init() - - def forward(self, pixel_values): - if len(pixel_values.shape) == 4: - x = self.vision_tower.forward_features_unpool(pixel_values) - else: - raise ValueError(f"invalid image shape {pixel_values.shape}") - return x - - -@add_start_docstrings( - """The FLORENCE2 vision model with projection layer""", - FLORENCE2_START_DOCSTRING, -) -class Florence2VisionModelWithProjection(Florence2PreTrainedModel): - def __init__(self, config: Florence2VisionConfig): - super().__init__(config) - assert config.model_type == "davit", "only DaViT is supported for now" - self.vision_tower = DaViT.from_config(config=config) - - self._build_image_projection_layers(config) - - self.post_init() - - def _build_image_projection_layers(self, config): - image_dim_out = config.dim_embed[-1] - dim_projection = config.projection_dim - self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection)) - self.image_proj_norm = nn.LayerNorm(dim_projection) - image_pos_embed_config = config.image_pos_embed - if image_pos_embed_config["type"] == "learned_abs_2d": - self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( - embedding_dim=image_dim_out, num_pos=image_pos_embed_config["max_pos_embeddings"] - ) - else: - raise NotImplementedError("Not implemented yet") - - self.image_feature_source = config.image_feature_source - - # temporal embedding - visual_temporal_embedding_config = config.visual_temporal_embedding - if visual_temporal_embedding_config["type"] == "COSINE": - self.visual_temporal_embed = PositionalEmbeddingCosine1D( - embed_dim=image_dim_out, - max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"], - ) - else: - raise NotImplementedError("Not implemented yet") - - def forward(self, pixel_values): - if len(pixel_values.shape) == 4: - batch_size, channels, height, width = pixel_values.shape - num_frames = 1 - x = self.vision_tower.forward_features_unpool(pixel_values) - else: - raise ValueError(f"invalid image shape {pixel_values.shape}") - - if self.image_pos_embed is not None: - x = x.view(batch_size * num_frames, -1, x.shape[-1]) - num_tokens = x.shape[-2] - h, w = int(num_tokens**0.5), int(num_tokens**0.5) - assert h * w == num_tokens, "only support square feature maps for now" - x = x.view(batch_size * num_frames, h, w, x.shape[-1]) - pos_embed = self.image_pos_embed(x) - x = x + pos_embed - x = x.view(batch_size, num_frames * h * w, x.shape[-1]) - - if self.visual_temporal_embed is not None: - visual_temporal_embed = self.visual_temporal_embed( - x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0] - ) - x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view( - 1, num_frames, 1, x.shape[-1] - ) - - x_feat_dict = {} - - spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2) - x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x - - temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1) - x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x - - x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1] - x_feat_dict["last_frame"] = x - - new_x = [] - for _image_feature_source in self.image_feature_source: - if _image_feature_source not in x_feat_dict: - raise ValueError(f"invalid image feature source: {_image_feature_source}") - new_x.append(x_feat_dict[_image_feature_source]) - - x = torch.cat(new_x, dim=1) - - x = x @ self.image_projection - x = self.image_proj_norm(x) - - return x - - @add_start_docstrings( """The FLORENCE2 model which consists of a vision backbone and a language model.""", FLORENCE2_START_DOCSTRING, diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index bc691ac04..8175f86eb 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -19,6 +19,7 @@ from __future__ import annotations import builtins +import logging import os from collections import deque from pathlib import Path @@ -352,7 +353,7 @@ class XVLAPolicy(PreTrainedPolicy): feature = self.config.action_feature if feature is None: return actions - desired_dim = feature.shape[0] + desired_dim = self.model.dim_action if desired_dim == actions.shape[-1]: return actions if desired_dim < actions.shape[-1]: @@ -434,7 +435,7 @@ class XVLAPolicy(PreTrainedPolicy): instance = cls(config, **kwargs) # step 2: locate model.safetensors if os.path.isdir(model_id): - print("Loading weights from local directory") + logging.info("Loading weights from local directory") model_file = os.path.join(model_id, "model.safetensors") else: try: @@ -455,7 +456,7 @@ class XVLAPolicy(PreTrainedPolicy): except HfHubHTTPError as e: raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e - print(f"Loading checkpoint from {model_file}") + logging.info(f"Loading checkpoint from {model_file}") # step 3: load state dict state_dict = safetensors.torch.load_file(model_file) encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight" @@ -465,7 +466,7 @@ class XVLAPolicy(PreTrainedPolicy): # or deepcopy # step 4: load into instance instance.load_state_dict(state_dict, strict=True) - print("Loaded XVLA checkpoint") + logging.info("Loaded XVLA checkpoint") # step 5: finalize # Reapply dtype after loading state dict instance.model._apply_dtype() diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index 374c9f4ac..1c813634e 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -70,8 +70,8 @@ def make_xvla_pre_post_processors( ), XVLAImageToFloatProcessorStep(), XVLAImageNetNormalizeProcessorStep(), - DeviceProcessorStep(device=config.device), XVLAAddDomainIdProcessorStep(), + DeviceProcessorStep(device=config.device), NormalizerProcessorStep( features=features, norm_map=config.normalization_mapping, stats=dataset_stats ), @@ -426,11 +426,9 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep): Args: domain_id: The domain ID to add (default: 3) - device: Device to place the domain_id tensor on (default: "cuda") """ domain_id: int = 0 - device: str = "cuda" def __call__(self, transition: EnvTransition) -> EnvTransition: """Add domain_id to complementary data.""" @@ -448,7 +446,7 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep): break # Add domain_id tensor - comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long).to(self.device) + comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long) new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp return new_transition @@ -461,7 +459,6 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep): """Return serializable configuration.""" return { "domain_id": self.domain_id, - "device": self.device, }