diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py index ecf3eb371..2e034ebd2 100644 --- a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -22,7 +22,7 @@ from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, + is_flash_attn_greater_or_equal, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.10") def forward( self, diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py index e33efe5c3..1cdeed781 100644 --- a/src/lerobot/policies/xvla/modeling_florence2.py +++ b/src/lerobot/policies/xvla/modeling_florence2.py @@ -45,7 +45,7 @@ from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, + is_flash_attn_greater_or_equal, logging, replace_return_docstrings, ) @@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.10") def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)