diff --git a/src/lerobot/policies/wall_x/README.md b/src/lerobot/policies/wall_x/README.md index b43be409e..78548bd8d 100644 --- a/src/lerobot/policies/wall_x/README.md +++ b/src/lerobot/policies/wall_x/README.md @@ -6,12 +6,13 @@ This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Languag ## Model Overview -| Feature | Description | -| -------------------- | ------------------------------------------------------------------------ | -| Base Model | Qwen2.5-VL (Vision-Language Model) | -| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | -| Architecture | Mixture of Experts (MoE) with action-specific routing | | -| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | +| Feature | Description | +| ------------------ | ----------------------------------------------------- | --- | +| Base Model | Qwen2.5-VL (Vision-Language Model) | +| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | +| Architecture | Mixture of Experts (MoE) with action-specific routing | | +| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | + --- ## Citation @@ -32,4 +33,3 @@ If you use this work, please cite: ## License This port follows the **Apache 2.0 License**. - diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py index 456ac993e..0d10a8f98 100644 --- a/src/lerobot/policies/wall_x/configuration_wall_x.py +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -49,7 +49,7 @@ class WallXConfig(PreTrainedConfig): } ) - # ==================== Action Prediction ==================== + # ==================== Action Prediction ==================== # Pretrained model paths pretrained_name_or_path: str = "x-square-robot/wall-oss-flow" @@ -85,20 +85,16 @@ class WallXConfig(PreTrainedConfig): ) if self.prediction_mode not in ["diffusion", "fast"]: - raise ValueError( - f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}" - ) + raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}") # Assign use_fast_tokenizer based on prediction_mode if self.prediction_mode == "fast": self.use_fast_tokenizer = True elif self.prediction_mode == "diffusion": self.use_fast_tokenizer = False - self.action_tokenizer_path = None # disable action tokenizer for diffusion mode + self.action_tokenizer_path = None # disable action tokenizer for diffusion mode else: - raise ValueError( - f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}" - ) + raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}") def validate_features(self) -> None: """Validate and set up input/output features.""" diff --git a/src/lerobot/policies/wall_x/constant.py b/src/lerobot/policies/wall_x/constant.py index 597d24951..43e5e7fb6 100644 --- a/src/lerobot/policies/wall_x/constant.py +++ b/src/lerobot/policies/wall_x/constant.py @@ -38,4 +38,4 @@ PRIORITY_ORDER = None GENERATE_SUBTASK_RATIO = 0.0 MODEL_TYPE = "qwen2_5" -TOKENIZER_MAX_LENGTH = 768 \ No newline at end of file +TOKENIZER_MAX_LENGTH = 768 diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index 16175127d..c401c8d60 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -34,61 +34,57 @@ lerobot-train \ ``` """ - import math -from os import PathLike from collections import deque -from typing import Any, Dict, List, Optional, Tuple, Union -from PIL import Image +from os import PathLike +from typing import Any import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from peft import LoraConfig, get_peft_model +from PIL import Image +from qwen_vl_utils.vision_process import smart_resize from torch import Tensor from torch.distributions import Beta from torch.nn import CrossEntropyLoss from torchdiffeq import odeint -from transformers import AutoProcessor +from transformers import AutoProcessor, BatchFeature from transformers.cache_utils import ( StaticCache, ) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration, +) from transformers.utils import is_torchdynamo_compiling, logging -from transformers import AutoProcessor, BatchFeature -from qwen_vl_utils.vision_process import smart_resize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import populate_queues from lerobot.policies.wall_x.configuration_wall_x import WallXConfig -from lerobot.utils.constants import ACTION, OBS_STATE - -from lerobot.policies.wall_x.utils import ( - replace_action_token, - preprocesser_call, - get_wallx_normal_text, - process_grounding_points, -) from lerobot.policies.wall_x.constant import ( - MODEL_TYPE, - TOKENIZER_MAX_LENGTH, - PRIORITY_ORDER, GENERATE_SUBTASK_RATIO, - RESOLUTION, + IMAGE_FACTOR, MAX_PIXELS, MIN_PIXELS, - IMAGE_FACTOR, + MODEL_TYPE, + PRIORITY_ORDER, + RESOLUTION, + TOKENIZER_MAX_LENGTH, ) from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLForConditionalGeneration, -) - from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import ( Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLACausalLMOutputWithPast, Qwen2_5_VLMoEModel, ) +from lerobot.policies.wall_x.utils import ( + get_wallx_normal_text, + preprocesser_call, + process_grounding_points, + replace_action_token, +) +from lerobot.utils.constants import ACTION, OBS_STATE logger = logging.get_logger(__name__) @@ -151,7 +147,7 @@ class ActionHead(nn.Module): """Sample timesteps using Beta distribution (always in float32 for numerical stability).""" beta_dist = Beta( torch.tensor(self.beta_alpha, dtype=torch.float32, device=device), - torch.tensor(self.beta_beta, dtype=torch.float32, device=device) + torch.tensor(self.beta_beta, dtype=torch.float32, device=device), ) sample = beta_dist.sample([batch_size]) time = (1 - sample) * self.s @@ -204,7 +200,7 @@ class ActionHead(nn.Module): def step(self, timestep, noisy_action, dof_mask=None): """Single denoising step for inference.""" weight_dtype = self.w1.weight.dtype - + if dof_mask is not None: noisy_action = torch.cat([noisy_action, dof_mask], dim=-1) noisy_action = noisy_action.to(dtype=weight_dtype) @@ -226,7 +222,7 @@ class ActionHead(nn.Module): # Ensure all inputs are float32 action_hidden_states = action_hidden_states.to(torch.float32) flow = flow.to(torch.float32) - + action_pred = self.action_proj_back(action_hidden_states) loss = F.mse_loss(action_pred, flow, reduction="none") @@ -275,14 +271,14 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): pretrained_name_or_path, config=None, action_tokenizer_path=None, - attn_implementation: str = 'eager', + attn_implementation: str = "eager", cache_dir: str | PathLike | None = None, force_download: bool = False, local_files_only: bool = False, token: str | bool | None = None, revision: str = "main", strict: bool = False, - **kwargs: Any + **kwargs: Any, ): """ Load model from pretrained model path. @@ -312,9 +308,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): config._attn_implementation = attn_implementation processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True) if action_tokenizer_path is not None: - action_tokenizer = AutoProcessor.from_pretrained( - action_tokenizer_path, trust_remote_code=True - ) + action_tokenizer = AutoProcessor.from_pretrained(action_tokenizer_path, trust_remote_code=True) processor.action_processor = action_tokenizer else: action_tokenizer = None @@ -387,9 +381,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): super().__init__(config) # Initialize vision transformer and language model components - self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( - config.vision_config - ) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) self.model = Qwen2_5_VLMoEModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -432,7 +424,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): params_to_keep_float32.append(name) if "action_preprocessor" in name: params_to_keep_float32.append(name) - + for name, param in self.named_parameters(): if name in params_to_keep_float32: param.data = param.data.to(torch.float32) @@ -446,12 +438,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): # Create list of fast action token IDs fast_action_token_list = [] if self.use_fast_tokenizer: - for i in range( - self.processor.tokenizer.init_kwargs["action_token_vocab_size"] - ): - action_token_id = self.processor.tokenizer.convert_tokens_to_ids( - f"<|action_token_{i}|>" - ) + for i in range(self.processor.tokenizer.init_kwargs["action_token_vocab_size"]): + action_token_id = self.processor.tokenizer.convert_tokens_to_ids(f"<|action_token_{i}|>") fast_action_token_list.append(action_token_id) # Get special action token IDs @@ -465,9 +453,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): "action_token_id": action_token_id, } - def add_lora( - self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1 - ): + def add_lora(self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1): """ Add LoRA (Low-Rank Adaptation) adapters to the model. @@ -516,12 +502,12 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): def get_rope_index( self, - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Calculate 3D RoPE (Rotary Position Embedding) indices for vision and text tokens. @@ -555,9 +541,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): vision_start_token_id = self.config.vision_start_token_id mrope_position_deltas = [] - if input_ids is not None and ( - image_grid_thw is not None or video_grid_thw is not None - ): + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): total_input_ids = input_ids if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) @@ -580,9 +564,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): image_nums, video_nums = 0, 0 # Find vision tokens and count images/videos - vision_start_indices = torch.argwhere( - input_ids == vision_start_token_id - ).squeeze(1) + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() @@ -641,14 +623,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): text_len = ed - st # Add position IDs for text tokens before vision token - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) # Calculate 3D position embeddings for vision tokens range_tensor = torch.arange(llm_grid_t).view(-1, 1) @@ -656,71 +632,43 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): # Calculate temporal position IDs with time scaling time_tensor = ( - expanded_range - * second_per_grid_t - * self.config.vision_config.tokens_per_second + expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second ) time_tensor_long = time_tensor.long() t_index = time_tensor_long.flatten() # Calculate spatial position IDs h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() + torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() ) w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() + torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() ) # Add 3D position IDs for vision tokens - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) st = ed + llm_grid_t * llm_grid_h * llm_grid_w # Add position IDs for remaining text tokens if st < len(input_tokens): - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) # Concatenate all position IDs for this sequence llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( - position_ids.device - ) - mrope_position_deltas.append( - llm_positions.max() + 1 - len(total_input_ids[i]) - ) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) - mrope_position_deltas = torch.tensor( - mrope_position_deltas, device=input_ids.device - ).unsqueeze(1) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas else: # Handle case without vision tokens - use standard 1D position embeddings if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = ( - position_ids.unsqueeze(0) - .expand(3, -1, -1) - .to(attention_mask.device) - ) - max_position_ids = position_ids.max(0, keepdim=False)[0].max( - -1, keepdim=True - )[0] + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: position_ids = ( @@ -739,33 +687,29 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): def train_step_forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - moe_token_types: Optional[ - torch.LongTensor - ] = None, # MoE token type assignments - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - action_chunk: Optional[torch.FloatTensor] = None, # Action trajectory chunks - proprioception: Optional[ - torch.FloatTensor - ] = None, # Joint position/orientation data - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - dof_mask: Optional[torch.FloatTensor] = None, - agent_pos_mask: Optional[torch.FloatTensor] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + moe_token_types: torch.LongTensor | None = None, # MoE token type assignments + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + action_chunk: torch.FloatTensor | None = None, # Action trajectory chunks + proprioception: torch.FloatTensor | None = None, # Joint position/orientation data + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + dof_mask: torch.FloatTensor | None = None, + agent_pos_mask: torch.FloatTensor | None = None, **kwargs, - ) -> Union[Tuple, Qwen2_5_VLACausalLMOutputWithPast]: + ) -> tuple | Qwen2_5_VLACausalLMOutputWithPast: """ Forward pass for training with multi-modal inputs including vision, text, and action data. @@ -806,24 +750,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): # Set output configuration from model config if not specified output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Calculate RoPE position IDs if not provided # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation - if position_ids is None and ( - attention_mask is None or attention_mask.ndim == 2 - ): + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # Calculate RoPE index once per generation in the pre-fill stage only if ( (cache_position is not None and cache_position[0] == 0) @@ -865,9 +801,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) image_mask = mask_expanded.to(inputs_embeds.device) - image_embeds = image_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # Process video embeddings @@ -887,19 +821,13 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) video_mask = mask_expanded.to(inputs_embeds.device) - video_embeds = video_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # Process proprioceptive data (joint positions, orientations, etc.) if proprioception is not None: - proprioception = proprioception.to(inputs_embeds.device).to( - inputs_embeds.dtype - ) - agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to( - inputs_embeds.dtype - ) + proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype) + agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) proprioception = self.action_preprocessor.proprioception_proj( proprioception, agent_pos_mask, @@ -910,12 +838,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) proprioception_mask = mask_expanded.to(inputs_embeds.device) - proprioception = proprioception.to( - inputs_embeds.device, inputs_embeds.dtype - ) - inputs_embeds = inputs_embeds.masked_scatter( - proprioception_mask, proprioception - ) + proprioception = proprioception.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(proprioception_mask, proprioception) elif self.training: # Dummy forward pass to ensure gradient registration in DDP # This handles cases where one process has proprioception data while another doesn't @@ -925,32 +849,22 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): self.action_preprocessor.propri_dim * 2, device=inputs_embeds.device, ) - dummy_forward = self.action_preprocessor.proprioception_proj( - dummy_input - ) + dummy_forward = self.action_preprocessor.proprioception_proj(dummy_input) dummy_loss = sum(p.sum() for p in dummy_forward) inputs_embeds = inputs_embeds + 0 * dummy_loss # Process action chunk data if action_chunk is not None: - action_chunk = action_chunk.to(inputs_embeds.device).to( - inputs_embeds.dtype - ) + action_chunk = action_chunk.to(inputs_embeds.device).to(inputs_embeds.dtype) dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) - noisy_action_emb, flow = self.action_preprocessor( - action_chunk, dof_mask - ) + noisy_action_emb, flow = self.action_preprocessor(action_chunk, dof_mask) mask = input_ids == self.action_token_id_set["action_token_id"] mask_unsqueezed = mask.unsqueeze(-1) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) action_mask = mask_expanded.to(inputs_embeds.device) - noisy_action_emb = noisy_action_emb.to( - inputs_embeds.device, inputs_embeds.dtype - ) - inputs_embeds = inputs_embeds.masked_scatter( - action_mask, noisy_action_emb - ) + noisy_action_emb = noisy_action_emb.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(action_mask, noisy_action_emb) if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) @@ -1011,18 +925,14 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): if action_mask.any(): action_hidden_states = hidden_states[action_mask].to(torch.float32) flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32) - _flow_loss = self.action_preprocessor.flow_loss( - action_hidden_states, flow, dof_mask - ) + _flow_loss = self.action_preprocessor.flow_loss(action_hidden_states, flow, dof_mask) if isinstance(_flow_loss, torch.Tensor): flow_loss = _flow_loss.mean() if loss is not None: loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32) else: loss = self.flow_loss_weight * flow_loss.to(torch.float32) - _flow_loss = _flow_loss.view( - dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2] - ) + _flow_loss = _flow_loss.view(dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2]) # Return outputs based on return_dict setting if not return_dict: @@ -1031,9 +941,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): return Qwen2_5_VLACausalLMOutputWithPast( loss=loss, - cross_entropy_loss=( - cross_entropy_loss.clone() if cross_entropy_loss is not None else None - ), + cross_entropy_loss=(cross_entropy_loss.clone() if cross_entropy_loss is not None else None), flow_loss=flow_loss, logits=logits, past_key_values=outputs.past_key_values, @@ -1065,31 +973,31 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): def predict( self, predict_mode: str, - pred_horizon: Optional[int] = None, - action_dim: Optional[int] = None, + pred_horizon: int | None = None, + action_dim: int | None = None, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - moe_token_types: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - action_chunk: Optional[torch.FloatTensor] = None, - proprioception: Optional[torch.FloatTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - num_inference_timesteps: Optional[int] = 10, - dof_mask: Optional[torch.FloatTensor] = None, - agent_pos_mask: Optional[torch.FloatTensor] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + moe_token_types: torch.LongTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + action_chunk: torch.FloatTensor | None = None, + proprioception: torch.FloatTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + num_inference_timesteps: int | None = 10, + dof_mask: torch.FloatTensor | None = None, + agent_pos_mask: torch.FloatTensor | None = None, re_generate: bool = False, **kwargs, ): @@ -1139,30 +1047,20 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): - 'predict_output_text': Generated text (for text/fast modes) - 'gt_output_text': Ground truth text (for text/fast modes) """ - batch_size = ( - input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] - ) + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] # Text and fast modes require batch size 1 for autoregressive generation if predict_mode in ["text", "fast"]: - assert ( - batch_size == 1 - ), "predict only support batch size 1 for ar generation" + assert batch_size == 1, "predict only support batch size 1 for ar generation" # Set output configuration from model config if not specified output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Process input embeddings with multi-modal data if inputs_embeds is None: @@ -1186,9 +1084,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) image_mask = mask_expanded.to(inputs_embeds.device) - image_embeds = image_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # Process video embeddings @@ -1209,40 +1105,28 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) video_mask = mask_expanded.to(inputs_embeds.device) - video_embeds = video_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # Process proprioceptive data if proprioception is not None: - proprioception = proprioception.to(inputs_embeds.device).to( - inputs_embeds.dtype - ) - agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to( - inputs_embeds.dtype - ) + proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype) + agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) proprio_embed = self.action_preprocessor.proprioception_proj( proprioception, agent_pos_mask, use_history=proprioception.shape[1] > 1, ) - proprioception_mask = ( - input_ids == self.action_token_id_set["propri_token_id"] - ) + proprioception_mask = input_ids == self.action_token_id_set["propri_token_id"] proprio_embed = proprio_embed.to(torch.bfloat16) - inputs_embeds[proprioception_mask] = proprio_embed.reshape( - -1, inputs_embeds.shape[-1] - ) + inputs_embeds[proprioception_mask] = proprio_embed.reshape(-1, inputs_embeds.shape[-1]) if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) # Calculate RoPE position IDs if not provided # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation - if position_ids is None and ( - attention_mask is None or attention_mask.ndim == 2 - ): + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # Calculate RoPE index once per generation in the pre-fill stage only if ( (cache_position is not None and cache_position[0] == 0) @@ -1332,12 +1216,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): eos_token_id=[self.processor.tokenizer.eos_token_id], use_cache=True, pad_token_id=self.processor.tokenizer.pad_token_id, - temperature=( - 1.0 if not re_generate else 0.7 - ), # Higher temperature for regeneration - do_sample=( - False if not re_generate else True - ), # Enable sampling for regeneration + temperature=(1.0 if not re_generate else 0.7), # Higher temperature for regeneration + do_sample=(False if not re_generate else True), # Enable sampling for regeneration ) # Decode generated and ground truth text @@ -1359,15 +1239,9 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): action_id = [] # Extract action tokens from generated sequence for token_id_i in predict_output_ids[0]: - if ( - token_id_i.item() - >= self.processor.tokenizer.init_kwargs["action_token_start_index"] - ): + if token_id_i.item() >= self.processor.tokenizer.init_kwargs["action_token_start_index"]: action_id.append( - token_id_i.item() - - self.processor.tokenizer.init_kwargs[ - "action_token_start_index" - ] + token_id_i.item() - self.processor.tokenizer.init_kwargs["action_token_start_index"] ) predict_action = self.processor.action_processor.decode( @@ -1382,9 +1256,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): predict_action = torch.tensor(predict_action, device=self.device) dof_mask = dof_mask.to(self.device).to(pixel_values.dtype) # removed unnormalization step for now - predict_action = ( - predict_action[:, :, dof_mask[0, 0, :].bool()] - ) + predict_action = predict_action[:, :, dof_mask[0, 0, :].bool()] output["predict_action"] = predict_action # Process ground truth actions if available @@ -1426,7 +1298,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): timestep=timestep, noisy_action=noisy_action, dof_mask=dof_mask ) action_embed = action_embed.reshape(-1, inputs_embeds.shape[-1]) - + # Ensure action_embed has the correct dtype and device before assignment action_embed = action_embed.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) @@ -1459,7 +1331,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): times = torch.linspace( 0, 1, - num_inference_timesteps+1, + num_inference_timesteps + 1, device=inputs_embeds.device, dtype=torch.float32, ) @@ -1477,9 +1349,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): return output - def forward( - self, mode: Optional[str] = None, predict_mode: Optional[str] = "text", **kwargs - ): + def forward(self, mode: str | None = None, predict_mode: str | None = "text", **kwargs): """ Main forward pass dispatcher for different execution modes. @@ -1592,9 +1462,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): # Handle input slicing based on cache state and special cases if past_key_values is not None: - if ( - inputs_embeds is not None and input_ids.shape[1] == 0 - ): # Exception 4: input_embeds case + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4: input_embeds case inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] moe_token_types = moe_token_types[:, -cache_position.shape[0] :] elif inputs_embeds is not None or ( # Exception 1: input_embeds provided @@ -1602,9 +1470,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): ): # Exception 3: GPU sync edge case input_ids = input_ids[:, -cache_position.shape[0] :] moe_token_types = moe_token_types[:, -cache_position.shape[0] :] - elif ( - input_ids.shape[1] != cache_position.shape[0] - ): # Default case (Exception 2 is no-op) + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (Exception 2 is no-op) cache_pos = cache_position.clone() input_ids = input_ids[:, cache_pos] moe_token_types = moe_token_types[:, cache_pos] @@ -1629,18 +1495,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): batch_size, sequence_length = input_ids.shape device = input_ids.device - attention_mask = ( - self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - config=self.config, - past_key_values=past_key_values, - ) + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, ) # Assemble all model inputs for generation @@ -1666,8 +1530,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): def _get_image_nums_and_video_nums( self, - input_ids: Optional[torch.LongTensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids: torch.LongTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Get the number of images and videos for each sample to calculate tensor separation lengths. @@ -1702,9 +1566,9 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): self, expand_size: int = 1, is_encoder_decoder: bool = False, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, **model_kwargs, - ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + ) -> tuple[torch.LongTensor, dict[str, Any]]: """ Expand inputs for generation with support for multi-modal tensors. @@ -1745,9 +1609,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): """Split tensor by lengths and repeat each sample.""" samples = torch.split(x, lengths) repeat_args = [repeat_times] + [1] * (x.dim() - 1) - result = torch.cat( - [sample.repeat(*repeat_args) for sample in samples], dim=0 - ) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) return result for key in dict_to_expand: @@ -1785,9 +1647,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): ) tensor = torch.tensor(dict_to_expand[key]) lengths = list(video_nums) - tensor = _repeat_interleave_samples( - tensor, lengths=lengths, repeat_times=expand_size - ) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) dict_to_expand[key] = tensor.tolist() return dict_to_expand @@ -1800,9 +1660,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): and isinstance(dict_to_expand[key], torch.Tensor) and key not in visual_keys ): - dict_to_expand[key] = dict_to_expand[key].repeat_interleave( - expand_size, dim=0 - ) + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand # Expand visual inputs only if input_ids is available for counting images/videos @@ -1823,9 +1681,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): raise ValueError( "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." ) - model_kwargs["encoder_outputs"] = _expand_dict_for_generation( - model_kwargs["encoder_outputs"] - ) + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) return input_ids, model_kwargs @@ -1848,9 +1704,9 @@ class WallXPolicy(PreTrainedPolicy): # Initialize the wall-x model self.model = Qwen2_5_VLMoEForAction.from_pretrained( - pretrained_name_or_path=config.pretrained_name_or_path, + pretrained_name_or_path=config.pretrained_name_or_path, action_tokenizer_path=config.action_tokenizer_path, - attn_implementation=config.attn_implementation + attn_implementation=config.attn_implementation, ) self.model.to(config.device) self.model.to_bfloat16_for_selected_params() @@ -1869,49 +1725,49 @@ class WallXPolicy(PreTrainedPolicy): def preprocess_inputs( self, - batch: Dict[str, Any], + batch: dict[str, Any], ) -> BatchFeature: """ Convert a batch of LeRobot dataset items to Wall-X model input format. - + This processes a batched dictionary where tensors have batch dimension first. - + Args: batch: Dictionary with batched tensors: - "observation.state": (batch_size, state_dim) or (batch_size, n_obs_steps, state_dim) - "action": (batch_size, chunk_size, action_dim) - "observation.images.": (batch_size, C, H, W) - "task": List[str] of length batch_size - + Returns: BatchFeature containing batched model inputs """ use_fast_tokenizer = self.config.use_fast_tokenizer - + # Get batch size from state tensor batch_size = batch[OBS_STATE].shape[0] - + # ==================== PROCESS ALL SAMPLES ==================== all_image_inputs = [] all_texts = [] - + # Find image keys in batch img_keys = [key for key in self.config.image_features if key in batch] - + for i in range(batch_size): # Vision preprocessing per sample processed_frames = [] orig_height, orig_width = None, None resized_height, resized_width = None, None - + for key in img_keys: current_obs = batch[key][i].clone() # (C, H, W) if current_obs.dim() == 3: current_obs = current_obs.permute(1, 2, 0) # (H, W, C) - + img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy()) orig_width, orig_height = img_pil.size - + target_size = RESOLUTION if target_size != -1: if orig_width > orig_height: @@ -1921,7 +1777,7 @@ class WallXPolicy(PreTrainedPolicy): new_height = target_size new_width = int(target_size * orig_width / orig_height) img_pil = img_pil.resize((new_width, new_height)) - + current_width, current_height = img_pil.size resized_height, resized_width = smart_resize( current_height, @@ -1932,13 +1788,13 @@ class WallXPolicy(PreTrainedPolicy): ) resized_img = img_pil.resize((resized_width, resized_height)) processed_frames.append(resized_img) - + all_image_inputs.append(processed_frames) - + # Text preprocessing task_text = batch["task"][i] if isinstance(batch["task"], list) else batch["task"] instruction_info = {"instruction": task_text} - + frame_index = batch["frame_index"][i] if "frame_index" in batch else 0 complete_text, _ = get_wallx_normal_text( instruction_info, @@ -1948,57 +1804,76 @@ class WallXPolicy(PreTrainedPolicy): img_keys, generate_subtask_ratio=GENERATE_SUBTASK_RATIO, ) - + text = process_grounding_points( - complete_text, orig_height, orig_width, resized_height, resized_width, - MODEL_TYPE + complete_text, orig_height, orig_width, resized_height, resized_width, MODEL_TYPE ) all_texts.append(text) - # ==================== PROCESS AGENT POS ==================== agent_pos = batch[OBS_STATE] # (batch_size, state_dim) if agent_pos.dim() == 2: agent_pos = agent_pos.unsqueeze(1) # (batch_size, 1, state_dim) agent_pos_mask = (~torch.isnan(agent_pos)).float() agent_pos = agent_pos.nan_to_num(nan=0.0) - + if agent_pos.shape[-1] != 20: pad_size = 20 - agent_pos.shape[-1] - agent_pos = torch.cat([ - agent_pos, - torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device) - ], dim=-1) - agent_pos_mask = torch.cat([ - agent_pos_mask, - torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size, device=agent_pos_mask.device) - ], dim=-1) - + agent_pos = torch.cat( + [ + agent_pos, + torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device), + ], + dim=-1, + ) + agent_pos_mask = torch.cat( + [ + agent_pos_mask, + torch.zeros( + agent_pos_mask.shape[0], + agent_pos_mask.shape[1], + pad_size, + device=agent_pos_mask.device, + ), + ], + dim=-1, + ) + # ==================== PROCESS ACTIONS ==================== - action = batch.get(ACTION, None) # (batch_size, chunk_size, action_dim) + action = batch.get(ACTION) # (batch_size, chunk_size, action_dim) if action is not None: if action.dim() == 2: action = action.unsqueeze(1) dof_mask = (~torch.isnan(action)).float() action = action.nan_to_num(nan=0.0) - + if action.shape[-1] != 20: pad_size = 20 - action.shape[-1] - action = torch.cat([ - action, - torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device) - ], dim=-1) - dof_mask = torch.cat([ - dof_mask, - torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device) - ], dim=-1) + action = torch.cat( + [action, torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)], + dim=-1, + ) + dof_mask = torch.cat( + [ + dof_mask, + torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device), + ], + dim=-1, + ) else: action_dim = self.config.output_features["action"].shape[0] - dof_mask = torch.cat([ - torch.ones(batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device), - torch.zeros(batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device) - ], dim=-1) - + dof_mask = torch.cat( + [ + torch.ones( + batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device + ), + torch.zeros( + batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device + ), + ], + dim=-1, + ) + # ==================== ACTION TOKEN REPLACEMENT ==================== all_texts = replace_action_token( all_texts, @@ -2006,7 +1881,7 @@ class WallXPolicy(PreTrainedPolicy): self.model.action_tokenizer if use_fast_tokenizer else None, dof_mask, ) - + # ==================== TOKENIZATION ==================== inputs = preprocesser_call( processor=self.model.processor, @@ -2018,24 +1893,28 @@ class WallXPolicy(PreTrainedPolicy): return_tensors="pt", max_length=TOKENIZER_MAX_LENGTH, ) - + # ==================== ADDITIONAL INPUTS ==================== action_token_id = self.model.processor.tokenizer.convert_tokens_to_ids("<|action|>") moe_token_types = inputs.input_ids == action_token_id - + inputs["proprioception"] = agent_pos inputs["agent_pos_mask"] = agent_pos_mask inputs["action_chunk"] = action inputs["dof_mask"] = dof_mask inputs["moe_token_types"] = moe_token_types - inputs["frame_index"] = batch["frame_index"] if "frame_index" in batch else torch.zeros(batch_size, device=batch[OBS_STATE].device) - + inputs["frame_index"] = ( + batch["frame_index"] + if "frame_index" in batch + else torch.zeros(batch_size, device=batch[OBS_STATE].device) + ) + # Move all tensors to the correct device device = self.config.device for key, value in inputs.items(): if isinstance(value, torch.Tensor): inputs[key] = value.to(device) - + return inputs def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: @@ -2052,14 +1931,11 @@ class WallXPolicy(PreTrainedPolicy): tuple: (loss, loss_dict) """ batch = self.preprocess_inputs( - batch, + batch, ) # Call the underlying model's forward with mode="train" - outputs = self.model( - **batch, - mode="train" - ) + outputs = self.model(**batch, mode="train") # Extract losses from output loss = outputs.loss @@ -2087,7 +1963,7 @@ class WallXPolicy(PreTrainedPolicy): self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) batch = self.preprocess_inputs( - batch, + batch, ) if self.config.prediction_mode == "diffusion": @@ -2096,7 +1972,7 @@ class WallXPolicy(PreTrainedPolicy): action_dim=self.config.max_action_dim, pred_horizon=self.config.chunk_size, mode="predict", - predict_mode="diffusion" + predict_mode="diffusion", ) elif self.config.prediction_mode == "fast": output = self.model( @@ -2104,14 +1980,14 @@ class WallXPolicy(PreTrainedPolicy): action_dim=self.config.output_features["action"].shape[0], pred_horizon=self.config.chunk_size, mode="predict", - predict_mode="fast" + predict_mode="fast", ) else: raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented") # Extract action tensor from output dictionary actions = output["predict_action"] - + # Unpad actions to actual action dimension action_dim = self.config.output_features["action"].shape[0] actions = actions[:, :, :action_dim] @@ -2127,6 +2003,6 @@ class WallXPolicy(PreTrainedPolicy): # Use action queue if len(self._queues[ACTION]) == 0: actions = self.predict_action_chunk(batch) - self._queues[ACTION].extend(actions.transpose(0, 1)[:self.config.n_action_steps]) + self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) return self._queues[ACTION].popleft() diff --git a/src/lerobot/policies/wall_x/processor_wall_x.py b/src/lerobot/policies/wall_x/processor_wall_x.py index d8ad402ed..e4e281541 100644 --- a/src/lerobot/policies/wall_x/processor_wall_x.py +++ b/src/lerobot/policies/wall_x/processor_wall_x.py @@ -33,6 +33,8 @@ from lerobot.processor import ( ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + def make_wall_x_pre_post_processors( config: WallXConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, @@ -75,9 +77,7 @@ def make_wall_x_pre_post_processors( output_steps = [ UnnormalizerProcessorStep( - features=config.output_features, - norm_map=config.normalization_mapping, - stats=dataset_stats + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), DeviceProcessorStep(device="cpu"), ] @@ -123,9 +123,7 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep): new_complementary_data["task"] = f"{task}." elif isinstance(task, list) and all(isinstance(t, str) for t in task): # List of strings: format each - new_complementary_data["task"] = [ - t if t.endswith(".") else f"{t}." for t in task - ] + new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task] return new_complementary_data diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py index a21b0b348..2a3e5eac5 100644 --- a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -1,11 +1,12 @@ import math +from dataclasses import dataclass +from typing import Any + import torch import torch.nn as nn import torch.nn.functional as F -from dataclasses import dataclass from torch.nn import CrossEntropyLoss -from typing import Any, Dict, List, Optional, Tuple, Union - +from transformers import AutoConfig from transformers.activations import ACT2FN from transformers.cache_utils import ( Cache, @@ -13,7 +14,6 @@ from transformers.cache_utils import ( SlidingWindowCache, StaticCache, ) -from transformers import AutoConfig from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput @@ -28,12 +28,12 @@ from transformers.utils import ( logging, replace_return_docstrings, ) + from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func + from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.layers.rotary import apply_rotary_emb - from flash_attn import flash_attn_func else: flash_attn_varlen_func = None apply_rotary_emb = None @@ -50,6 +50,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + class Qwen2_5_VLMLP(nn.Module): def __init__(self, config, bias: bool = False): super().__init__() @@ -61,9 +62,7 @@ class Qwen2_5_VLMLP(nn.Module): self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_state): - return self.down_proj( - self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) - ) + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) class Qwen2_5_VisionPatchEmbed(nn.Module): @@ -98,9 +97,7 @@ class Qwen2_5_VisionPatchEmbed(nn.Module): self.patch_size, self.patch_size, ) - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( - -1, self.embed_dim - ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) return hidden_states @@ -111,9 +108,7 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module): self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange( - seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype - ) + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) freqs = torch.outer(seq, self.inv_freq) return freqs @@ -156,7 +151,7 @@ class Qwen2_5_VLPatchMerger(nn.Module): def apply_rotary_pos_emb_flashatt( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) @@ -175,16 +170,13 @@ class Qwen2_5_VLVisionFlashAttention2(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - max_seqlen: Optional[int] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = ( - self.qkv(hidden_states) - .reshape(seq_length, 3, self.num_heads, -1) - .permute(1, 0, 2, 3) - .unbind(0) + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) ) if position_embeddings is None: logger.warning_once( @@ -204,9 +196,9 @@ class Qwen2_5_VLVisionFlashAttention2(nn.Module): if max_seqlen is None: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func( - q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen - ).reshape(seq_length, -1) + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) attn_output = self.proj(attn_output) return attn_output @@ -220,7 +212,7 @@ def rotate_half(x): def apply_rotary_pos_emb_vision( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: orig_q_dtype = q.dtype orig_k_dtype = k.dtype q, k = q.float(), k.float() @@ -244,16 +236,13 @@ class Qwen2_5_VLVisionAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - max_seqlen: Optional[int] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = ( - self.qkv(hidden_states) - .reshape(seq_length, 3, self.num_heads, -1) - .permute(1, 0, 2, 3) - .unbind(0) + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) ) if position_embeddings is None: logger.warning_once( @@ -287,9 +276,7 @@ class Qwen2_5_VLVisionAttention(nn.Module): v = v.transpose(0, 1) attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(q.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) attn_output = torch.matmul(attn_weights, v) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) @@ -308,16 +295,13 @@ class Qwen2_5_VLVisionSdpaAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - max_seqlen: Optional[int] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = ( - self.qkv(hidden_states) - .reshape(seq_length, 3, self.num_heads, -1) - .permute(1, 0, 2, 3) - .unbind(0) + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) ) if position_embeddings is None: logger.warning_once( @@ -333,9 +317,7 @@ class Qwen2_5_VLVisionSdpaAttention(nn.Module): cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) - attention_mask = torch.zeros( - [1, seq_length, seq_length], device=q.device, dtype=torch.bool - ) + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) for i in range(1, len(cu_seqlens)): attention_mask[ ..., @@ -345,9 +327,7 @@ class Qwen2_5_VLVisionSdpaAttention(nn.Module): q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention( - q, k, v, attention_mask, dropout_p=0.0 - ) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) @@ -375,9 +355,9 @@ class Qwen2_5_VLVisionBlock(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - max_seqlen: Optional[int] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), @@ -406,6 +386,7 @@ Qwen2_5_VL_START_DOCSTRING = r""" [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ + @add_start_docstrings( "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", Qwen2_5_VL_START_DOCSTRING, @@ -419,7 +400,9 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True - _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` + _supports_static_cache = ( + False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` + ) def _init_weights(self, module): std = self.config.initializer_range @@ -456,10 +439,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( - [ - Qwen2_5_VLVisionBlock(config, config._attn_implementation) - for _ in range(config.depth) - ] + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] ) self.merger = Qwen2_5_VLPatchMerger( dim=config.out_hidden_size, @@ -501,18 +481,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): window_index: list = [] cu_window_seqlens: list = [0] window_index_id = 0 - vit_merger_window_size = ( - self.window_size // self.spatial_merge_size // self.patch_size - ) + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size, ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w - ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size @@ -535,18 +511,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] window_index.append(index_new + window_index_id) - cu_seqlens_tmp = ( - seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - ) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index = torch.cat(window_index, dim=0) return window_index, cu_window_seqlens - def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): @@ -564,13 +536,12 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): cu_window_seqlens = torch.tensor( cu_window_seqlens, device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,) + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) seq_len, _ = hidden_states.size() - hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 - ) + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape( @@ -581,9 +552,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] - ).cumsum( + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, # Select dtype based on the following factors: # - FA2 requires that cu_seqlens_q must have dtype int32 @@ -593,9 +562,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) max_seqlen_full = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - max_seqlen_window = ( - (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item() - ) + max_seqlen_window = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item() for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: @@ -632,9 +599,7 @@ class Qwen2_5_VLRotaryEmbedding(nn.Module): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get( - "rope_type", config.rope_scaling.get("type") - ) + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings @@ -664,8 +629,7 @@ class Qwen2_5_VLRotaryEmbedding(nn.Module): self.max_seq_len_cached = seq_len if ( - seq_len < self.original_max_seq_len - and self.max_seq_len_cached > self.original_max_seq_len + seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len ): # reset self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @@ -677,25 +641,13 @@ class Qwen2_5_VLRotaryEmbedding(nn.Module): # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for thw grids # So we expand the inv_freq to shape (3, ...) - inv_freq_expanded = ( - self.inv_freq[None, None, :, None] - .float() - .expand(3, position_ids.shape[1], -1, 1) - ) - position_ids_expanded = position_ids[ - :, :, None, : - ].float() # shape (3, bs, 1, positions) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type - device_type = ( - device_type - if isinstance(device_type, str) and device_type != "mps" - else "cpu" - ) + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = ( - inv_freq_expanded.float() @ position_ids_expanded.float() - ).transpose(2, 3) + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() @@ -756,12 +708,12 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ mrope_section = mrope_section * 2 - cos = torch.cat( - [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1 - ).unsqueeze(unsqueeze_dim) - sin = torch.cat( - [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1 - ).unsqueeze(unsqueeze_dim) + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -776,9 +728,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -788,7 +738,7 @@ class Qwen2_5_VLAttention(nn.Module): and "Generating Long Sequences with Sparse Transformers". """ - def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int | None = None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -813,34 +763,24 @@ class Qwen2_5_VLAttention(nn.Module): f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=True - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -870,9 +810,7 @@ class Qwen2_5_VLAttention(nn.Module): key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] @@ -886,12 +824,8 @@ class Qwen2_5_VLAttention(nn.Module): ) # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training - ) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -931,15 +865,13 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, # necessary, but kept here for BC + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC ): bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -1027,16 +959,14 @@ class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -1127,40 +1057,29 @@ class Qwen2_5_VLDecoderLayer(nn.Module): super().__init__() self.hidden_size = config.hidden_size - if ( - config.use_sliding_window - and config._attn_implementation != "flash_attention_2" - ): + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation]( - config, layer_idx - ) + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, # necessary, but kept here for BC + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC **kwargs, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -1227,14 +1146,9 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [ - Qwen2_5_VLDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] + [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self._attn_implementation = config._attn_implementation self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1253,36 +1167,28 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | BaseModelOutputWithPast: output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: @@ -1299,9 +1205,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], @@ -1310,9 +1214,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): # the hard coded `3` is for temporal, height and width. if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand( - 3, inputs_embeds.shape[0], -1 - ) + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) elif position_ids.dim() == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) @@ -1380,9 +1282,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): if not return_dict: return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -1401,9 +1301,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: - is_padding_right = ( - attention_mask[:, -1].sum().item() != input_tensor.size()[0] - ) + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" @@ -1417,9 +1315,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1474,9 +1370,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended( - causal_mask, min_dtype - ) + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @@ -1527,36 +1421,29 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): dtype=dtype, device=device, ) - diagonal_attend_mask = torch.arange( - target_length, device=device - ) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not - if ( - not isinstance(past_key_values, SlidingWindowCache) - or sequence_length > target_length - ): - sliding_attend_mask = torch.arange( - target_length, device=device - ) <= (cache_position.reshape(-1, 1) - config.sliding_window) + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: - causal_mask = ( - causal_mask.clone() - ) # copy to contiguous memory for in-place edit + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ - :, None, None, : - ].to(causal_mask.device) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[ - :, :, :, :mask_length - ].masked_fill(padding_mask, min_dtype) + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) return causal_mask @@ -1591,12 +1478,12 @@ class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): The rope index difference between sequence length and multimodal rope. """ - loss: Optional[torch.FloatTensor] = None + loss: torch.FloatTensor | None = None logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None QWEN2_5_VL_INPUTS_DOCSTRING = r""" @@ -1682,9 +1569,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi def __init__(self, config): super().__init__(config) - self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( - config.vision_config - ) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) self.model = Qwen2_5_VLModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1713,12 +1598,12 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi def get_rope_index( self, - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Calculate the 3D rope index based on image and video's temporal, height and width in LLM. @@ -1777,9 +1662,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id mrope_position_deltas = [] - if input_ids is not None and ( - image_grid_thw is not None or video_grid_thw is not None - ): + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): total_input_ids = input_ids if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) @@ -1795,9 +1678,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere( - input_ids == vision_start_token_id - ).squeeze(1) + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() @@ -1845,78 +1726,44 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi ) text_len = ed - st - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) range_tensor = torch.arange(llm_grid_t).view(-1, 1) expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) time_tensor = ( - expanded_range - * second_per_grid_t - * self.config.vision_config.tokens_per_second + expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second ) time_tensor_long = time_tensor.long() t_index = time_tensor_long.flatten() h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() + torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() ) w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx + torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() ) + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( - position_ids.device - ) - mrope_position_deltas.append( - llm_positions.max() + 1 - len(total_input_ids[i]) - ) - mrope_position_deltas = torch.tensor( - mrope_position_deltas, device=input_ids.device - ).unsqueeze(1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas else: if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = ( - position_ids.unsqueeze(0) - .expand(3, -1, -1) - .to(attention_mask.device) - ) - max_position_ids = position_ids.max(0, keepdim=False)[0].max( - -1, keepdim=True - )[0] + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: position_ids = ( @@ -1933,29 +1780,27 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi return position_ids, mrope_position_deltas @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC - ) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + ) -> tuple | Qwen2_5_VLCausalLMOutputWithPast: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1997,18 +1842,12 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi ```""" output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) @@ -2027,9 +1866,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) image_mask = mask_expanded.to(inputs_embeds.device) - image_embeds = image_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: @@ -2047,18 +1884,14 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) video_mask = mask_expanded.to(inputs_embeds.device) - video_embeds = video_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and ( - attention_mask is None or attention_mask.ndim == 2 - ): + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only if ( (cache_position is not None and cache_position[0] == 0) @@ -2187,18 +2020,16 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi batch_size, sequence_length = input_ids.shape device = input_ids.device - attention_mask = ( - self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - config=self.config, - past_key_values=past_key_values, - ) + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, ) model_inputs.update( @@ -2219,8 +2050,8 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi def _get_image_nums_and_video_nums( self, - input_ids: Optional[torch.LongTensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids: torch.LongTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Get the number of images and videos for each sample to calculate the separation length of the sample tensor. These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. @@ -2250,9 +2081,9 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi self, expand_size: int = 1, is_encoder_decoder: bool = False, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, **model_kwargs, - ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + ) -> tuple[torch.LongTensor, dict[str, Any]]: # Overwritten -- Support for expanding tensors without a batch size dimension # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t # pixel_values.shape[0] is sum(seqlen_images for samples) @@ -2277,9 +2108,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi def _repeat_interleave_samples(x, lengths, repeat_times): samples = torch.split(x, lengths) repeat_args = [repeat_times] + [1] * (x.dim() - 1) - result = torch.cat( - [sample.repeat(*repeat_args) for sample in samples], dim=0 - ) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) return result for key in dict_to_expand: @@ -2315,9 +2144,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi ) tensor = torch.tensor(dict_to_expand[key]) lengths = list(video_nums) - tensor = _repeat_interleave_samples( - tensor, lengths=lengths, repeat_times=expand_size - ) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) dict_to_expand[key] = tensor.tolist() return dict_to_expand @@ -2329,9 +2156,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi and isinstance(dict_to_expand[key], torch.Tensor) and key not in visual_keys ): - dict_to_expand[key] = dict_to_expand[key].repeat_interleave( - expand_size, dim=0 - ) + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand # input_ids is required for expanding visual inputs @@ -2349,25 +2174,24 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi raise ValueError( "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." ) - model_kwargs["encoder_outputs"] = _expand_dict_for_generation( - model_kwargs["encoder_outputs"] - ) + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) return input_ids, model_kwargs + @dataclass class Qwen2_5_VLACausalLMOutputWithPast(ModelOutput): - loss: Optional[torch.FloatTensor] = None - flow_loss: Optional[torch.FloatTensor] = None - cross_entropy_loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None + loss: torch.FloatTensor | None = None + flow_loss: torch.FloatTensor | None = None + cross_entropy_loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None - channel_loss_dict: Optional[dict[torch.FloatTensor]] = None - channel_loss_count_dict: Optional[dict[torch.FloatTensor]] = None + channel_loss_dict: dict[torch.FloatTensor] | None = None + channel_loss_count_dict: dict[torch.FloatTensor] | None = None class BlockSparseMLP(nn.Module): @@ -2383,9 +2207,7 @@ class BlockSparseMLP(nn.Module): self.act_fn = ACT2FN[self.hidden_act] def forward(self, hidden_state): - return self.down_proj( - self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) - ) + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) class SparseMoeBlock(nn.Module): @@ -2394,7 +2216,7 @@ class SparseMoeBlock(nn.Module): self.num_experts = num_experts self.experts = nn.ModuleList([BlockSparseMLP(config.experts[i]) for i in range(num_experts)]) - if not hasattr(config, 'dim_inputs') or not config.dim_inputs: + if not hasattr(config, "dim_inputs") or not config.dim_inputs: raise ValueError("Config must contain valid dim_inputs") self.dim_inputs = config.dim_inputs @@ -2405,7 +2227,7 @@ class SparseMoeBlock(nn.Module): Args: hidden_states (torch.Tensor): Tensor of shape (batch_size, seq_length, hidden_dim). - experts_indices (torch.Tensor): Tensor of shape (batch_size, seq_length), + experts_indices (torch.Tensor): Tensor of shape (batch_size, seq_length), indicating the expert index assigned to each token. Returns: @@ -2415,7 +2237,7 @@ class SparseMoeBlock(nn.Module): output = torch.zeros_like(hidden_states) for expert_idx, expert in enumerate(self.experts): - mask = (experts_indices == expert_idx) + mask = experts_indices == expert_idx if mask.sum() == 0: continue dim_input = self.dim_inputs[expert_idx] @@ -2441,23 +2263,16 @@ class Qwen2_5_VLDecoderLayer_with_MoE(nn.Module): super().__init__() self.hidden_size = config.hidden_size - if ( - config.use_sliding_window - and config._attn_implementation != "flash_attention_2" - ): + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation]( - config, layer_idx - ) + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.mlp_moe: self.moe = SparseMoeBlock(config, num_experts=num_experts) @@ -2468,18 +2283,16 @@ class Qwen2_5_VLDecoderLayer_with_MoE(nn.Module): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, token_types=None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -2524,9 +2337,7 @@ class Qwen2_5_VLDecoderLayer_with_MoE(nn.Module): hidden_states = self.post_attention_layernorm(hidden_states) if self.mlp is None: # using moe mlp hidden_states = hidden_states.to(self.moe.experts[0].down_proj.weight.dtype) - hidden_states = self.moe( - hidden_states, token_types - ) + hidden_states = self.moe(hidden_states, token_types) else: hidden_states = hidden_states.to(self.mlp.down_proj.weight.dtype) hidden_states = self.mlp(hidden_states) @@ -2541,6 +2352,7 @@ class Qwen2_5_VLDecoderLayer_with_MoE(nn.Module): outputs += (present_key_value,) return outputs + class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): """Qwen2.5-VL model with Mixture of Experts (MoE) architecture. @@ -2552,7 +2364,7 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): def from_pretrained( cls, pretrained_model_name_or_path: str, - num_experts: Optional[int] = None, + num_experts: int | None = None, *args, **kwargs, ): @@ -2567,7 +2379,7 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): Returns: Initialized model instance with MoE configuration """ - config = kwargs.get("config", None) + config = kwargs.get("config") if config is None: config = AutoConfig.from_pretrained(pretrained_model_name_or_path) @@ -2591,9 +2403,7 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): self.vocab_size = config.vocab_size # Model components - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # Decoder layers with MoE support self.layers = nn.ModuleList( @@ -2630,40 +2440,32 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): def forward( self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - moe_token_types: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + moe_token_types: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> tuple | BaseModelOutputWithPast: # Set default output options output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Validate inputs if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if moe_token_types is None: raise ValueError("moe_token_types must be provided for MoE routing") @@ -2686,9 +2488,7 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): # Set up cache position if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], @@ -2697,9 +2497,7 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): # Set up position IDs (hardcoded 3 dimensions for temporal, height, width) if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand( - 3, inputs_embeds.shape[0], -1 - ) + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) elif position_ids.dim() == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) @@ -2778,9 +2576,7 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): # Return outputs in requested format if not return_dict: return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( @@ -2797,7 +2593,7 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, - moe_token_types: Optional[torch.LongTensor] = None, + moe_token_types: torch.LongTensor | None = None, ): """Update causal attention mask with support for bidirectional attention for specific token types. @@ -2822,9 +2618,7 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): return None # Calculate sequence lengths for cache management - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -2879,24 +2673,18 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): # Modify mask to support bidirectional attention for specific token types if moe_token_types is not None: # Identify positions of type 1 tokens (MoE routing tokens) - type1_tokens = ( - (moe_token_types == 1).unsqueeze(1).unsqueeze(2) - ) # Shape: [B, 1, 1, S] + type1_tokens = (moe_token_types == 1).unsqueeze(1).unsqueeze(2) # Shape: [B, 1, 1, S] # Create bidirectional attention region for type 1 tokens # This allows type 1 tokens to attend to each other bidirectionally type1_mask = torch.zeros_like(causal_mask) # Shape: [B, num_heads, S, S] - type1_region = type1_tokens & type1_tokens.transpose( - -1, -2 - ) # Shape: [B, 1, S, S] + type1_region = type1_tokens & type1_tokens.transpose(-1, -2) # Shape: [B, 1, S, S] type1_mask = type1_mask.masked_fill(type1_region, 1.0).to(torch.bool) # Apply bidirectional attention: zero out causal constraints in type 1 regions causal_mask = torch.where( type1_mask, # Where type 1 tokens interact with each other - torch.zeros_like( - causal_mask - ), # Remove causal masking (allow bidirectional) + torch.zeros_like(causal_mask), # Remove causal masking (allow bidirectional) causal_mask, # Keep original causal masking for other regions ) @@ -2910,9 +2698,7 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): # Ensure attention to all tokens in fully masked rows for memory-efficient attention # This is required for F.scaled_dot_product_attention's memory-efficient path # when using left padding. See: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended( - causal_mask, min_dtype - ) + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @@ -2963,42 +2749,36 @@ class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): dtype=dtype, device=device, ) - diagonal_attend_mask = torch.arange( - target_length, device=device - ) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not - if ( - not isinstance(past_key_values, SlidingWindowCache) - or sequence_length > target_length - ): - sliding_attend_mask = torch.arange( - target_length, device=device - ) <= (cache_position.reshape(-1, 1) - config.sliding_window) + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: - causal_mask = ( - causal_mask.clone() - ) # copy to contiguous memory for in-place edit + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ - :, None, None, : - ].to(causal_mask.device) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[ - :, :, :, :mask_length - ].masked_fill(padding_mask, min_dtype) + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) return causal_mask + __all__ = [ "Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLMoEModel", -] \ No newline at end of file +] diff --git a/src/lerobot/policies/wall_x/utils.py b/src/lerobot/policies/wall_x/utils.py index 19d85aa66..bada4ebdf 100644 --- a/src/lerobot/policies/wall_x/utils.py +++ b/src/lerobot/policies/wall_x/utils.py @@ -25,7 +25,7 @@ import random import re from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any import torch from transformers import BatchFeature @@ -46,11 +46,11 @@ class X2RDataProcessingConfig: """ # Action prediction configuration - predict_action_keys: List[str] = field(default_factory=list) - obs_action_keys: List[str] = field(default_factory=list) + predict_action_keys: list[str] = field(default_factory=list) + obs_action_keys: list[str] = field(default_factory=list) # Image resolution settings for different views - resolution: Dict[str, int] = field( + resolution: dict[str, int] = field( default_factory=lambda: { "face_view": -1, "left_wrist_view": 128, @@ -63,7 +63,7 @@ class X2RDataProcessingConfig: split_seed: int = 42 # Instruction handling - priority_order: Optional[Dict[str, float]] = None + priority_order: dict[str, float] | None = None # Vision model parameters model_type: str = "qwen2_5" @@ -77,11 +77,9 @@ class X2RDataProcessingConfig: """Post-initialization validation and setup.""" # Validate train/test split if not 0 < self.train_test_split < 1: - raise ValueError( - f"train_test_split must be between 0 and 1, got {self.train_test_split}" - ) + raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}") - def as_dict(self) -> Dict: + def as_dict(self) -> dict: """Convert configuration to dictionary format. Returns: @@ -105,14 +103,15 @@ class X2RDataProcessingConfig: raise ValueError(f"Unknown configuration parameter: {key}") return self + def preprocesser_call( processor, - images: Optional[Union[List, Any]] = None, - text: Optional[Union[str, List[str]]] = None, - videos: Optional[Union[List, Any]] = None, - padding: Union[bool, str] = False, - truncation: Optional[bool] = None, - max_length: Optional[int] = None, + images: list | Any | None = None, + text: str | list[str] | None = None, + videos: list | Any | None = None, + padding: bool | str = False, + truncation: bool | None = None, + max_length: int | None = None, return_tensors: str = "pt", ) -> BatchFeature: """Unified preprocessing function for Wall-X model handling text, image and video inputs. @@ -145,9 +144,7 @@ def preprocesser_call( """ # Process image inputs if images is not None and len(images) > 0: - image_inputs = processor.image_processor( - images=images, videos=None, return_tensors=return_tensors - ) + image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors) image_grid_thw = image_inputs["image_grid_thw"] else: image_inputs = {} @@ -155,9 +152,7 @@ def preprocesser_call( # Process video inputs if videos is not None: - videos_inputs = processor.image_processor( - images=None, videos=videos, return_tensors=return_tensors - ) + videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors) video_grid_thw = videos_inputs["video_grid_thw"] else: videos_inputs = {} @@ -183,9 +178,7 @@ def preprocesser_call( break # Replace image placeholder with actual token count token_count = image_grid_thw[index].prod() // merge_length - text[i] = text[i].replace( - "<|image_pad|>", "<|placeholder|>" * token_count, 1 - ) + text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1) index += 1 text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>") @@ -197,9 +190,7 @@ def preprocesser_call( while "<|video_pad|>" in text[i]: # Replace video placeholder with actual token count token_count = video_grid_thw[index].prod() // merge_length - text[i] = text[i].replace( - "<|video_pad|>", "<|placeholder|>" * token_count, 1 - ) + text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1) index += 1 text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>") @@ -221,9 +212,7 @@ def preprocesser_call( labels = torch.full_like(text_inputs.input_ids, -100) assistant_marker = "<|im_start|>assistant\n" im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>") - assistant_tokens = processor.tokenizer( - "<|im_start|>assistant\n", add_special_tokens=False - ).input_ids + assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids for i in range(len(text)): assistant_regions = [] @@ -249,9 +238,7 @@ def preprocesser_call( # From second part onwards, each part starts with assistant response for k in range(current_pos + 1, len(text_inputs.input_ids[i])): if text_inputs.input_ids[i][k] == im_end_token_id: - assistant_regions.append( - (current_pos + len(assistant_tokens), k + 2) - ) + assistant_regions.append((current_pos + len(assistant_tokens), k + 2)) break current_pos += len(part_tokens) + 3 @@ -344,7 +331,7 @@ def process_grounding_points( coords = [new_x1, new_y1, new_x2, new_y2] # Return processed point tag - return f'[{", ".join(map(str, coords))}]' + return f"[{', '.join(map(str, coords))}]" except (ValueError, TypeError): # Return original content if processing fails @@ -356,10 +343,10 @@ def process_grounding_points( def get_frame_instruction( - instruction_info: Dict[str, Any], - frame_idx: Optional[int] = None, - truncate_keys: Optional[List[str]] = None, -) -> Tuple[Dict[str, Any], Optional[int]]: + instruction_info: dict[str, Any], + frame_idx: int | None = None, + truncate_keys: list[str] | None = None, +) -> tuple[dict[str, Any], int | None]: """Extract frame-specific instruction from instruction dictionary. Args: @@ -388,11 +375,7 @@ def get_frame_instruction( start_frame, end_frame = map(int, frame_range.split(" ")) if start_frame <= frame_idx < end_frame or (start_frame == frame_idx): instruction_for_frame[key] = frame_instruction - if ( - truncate_keys is not None - and split_end is None - and key in truncate_keys - ): + if truncate_keys is not None and split_end is None and key in truncate_keys: split_end = end_frame + 1 break else: @@ -402,7 +385,7 @@ def get_frame_instruction( def get_task_instruction( - frame_instruction_info: Dict[str, Any], priority_order: Optional[OrderedDict] = None + frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None ) -> str: """Construct task instruction from available instruction fields using priority sampling. @@ -450,13 +433,13 @@ def get_task_instruction( def get_wallx_normal_text( - instruction_info: Dict[str, Any], + instruction_info: dict[str, Any], action_chunk_size: int, frame_idx: int, - priority_order: Optional[OrderedDict] = None, - img_keys: Optional[List[str]] = None, + priority_order: OrderedDict | None = None, + img_keys: list[str] | None = None, generate_subtask_ratio: float = 0.0, -) -> Tuple[str, bool]: +) -> tuple[str, bool]: """Construct complete multimodal prompt text for Wall-X model. Formats input using special tokens including: @@ -488,9 +471,7 @@ def get_wallx_normal_text( action_fast_symbol = "<|action_fast|>" # System prologue - prologue = ( - f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n" - ) + prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n" # User request with observation user_request = f"{role_start_symbol}user\nObservation:" @@ -501,9 +482,7 @@ def get_wallx_normal_text( user_request += "\nInstruction:" # Get frame-specific instruction - frame_instruction_info, _ = get_frame_instruction( - instruction_info, frame_idx=frame_idx - ) + frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx) generate_subtask = False priority_keys = ["subtask_generation", "distribute"] @@ -524,15 +503,11 @@ def get_wallx_normal_text( output_instruction = frame_instruction_info[key] break - assistant_output = ( - f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}" - ) + assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}" generate_subtask = True else: # Generate actions - instruction = get_task_instruction( - frame_instruction_info, priority_order=priority_order - ) + instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order) text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n" user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}" @@ -540,7 +515,8 @@ def get_wallx_normal_text( complete_text = prologue + user_message + assistant_output return complete_text, generate_subtask -def img_key_mapping(img_keys: List[str]) -> List[str]: + +def img_key_mapping(img_keys: list[str]) -> list[str]: """Map image keys to camera names. Args: @@ -555,16 +531,15 @@ def img_key_mapping(img_keys: List[str]) -> List[str]: if key in CAMERA_NAME_MAPPING: key = CAMERA_NAME_MAPPING[key] else: - if 'view' in key: - key = key.replace('_', ' ') + if "view" in key: + key = key.replace("_", " ") else: key = key + " view" processed_img_keys.append(key) return processed_img_keys -def get_action_tokens( - normalized_actions: Union[torch.Tensor, List], action_tokenizer -) -> List[List[str]]: + +def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]: """Convert normalized actions to action token strings. Args: @@ -590,8 +565,8 @@ def get_action_tokens( def pad_action_token_strs( - actions_token_lists: List[List[str]], pad_token: str = "<|endoftext|>" -) -> List[str]: + actions_token_lists: list[list[str]], pad_token: str = "<|endoftext|>" +) -> list[str]: """Pad action token lists to same length and join as strings. Args: @@ -605,20 +580,18 @@ def pad_action_token_strs( padded_action_strs = [] for tokens in actions_token_lists: - padded_tokens = ( - tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens)) - ) + padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens)) padded_action_strs.append("".join(padded_tokens)) return padded_action_strs def replace_action_token( - text: List[str], - norm_action: Optional[torch.Tensor], + text: list[str], + norm_action: torch.Tensor | None, action_tokenizer, - dof_masks: Optional[torch.Tensor] = None, -) -> List[str]: + dof_masks: torch.Tensor | None = None, +) -> list[str]: """Replace action placeholders in text with actual action tokens. Args: @@ -632,10 +605,7 @@ def replace_action_token( """ if action_tokenizer is not None and norm_action is not None: # Extract actions based on chunk sizes and DOF masks - norm_action = [ - action[: 32, dof_masks[i, 0].bool()] - for i, action in enumerate(norm_action) - ] + norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)] # Convert to action tokens and pad actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer) @@ -658,4 +628,3 @@ def replace_action_token( text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text] return text - diff --git a/tests/policies/wall_x/test_wallx.py b/tests/policies/wall_x/test_wallx.py index 8286655f0..fca6686c4 100644 --- a/tests/policies/wall_x/test_wallx.py +++ b/tests/policies/wall_x/test_wallx.py @@ -35,10 +35,11 @@ from lerobot.policies.wall_x import ( # noqa: E402 ) from lerobot.utils.random_utils import set_seed # noqa: E402 + def test_policy_instantiation(): # Create config set_seed(42) - config = WallXConfig(device='cuda') + config = WallXConfig(device="cuda") # Set up input_features and output_features in the config from lerobot.configs.types import FeatureType, PolicyFeature @@ -118,6 +119,7 @@ def test_policy_instantiation(): print(f"Action prediction failed: {e}") raise + def test_config_creation(): """Test policy config creation through factory.""" try: @@ -130,6 +132,7 @@ def test_config_creation(): print(f"Config creation failed: {e}") raise + if __name__ == "__main__": test_policy_instantiation() - test_config_creation() \ No newline at end of file + test_config_creation()