mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
Update src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
This commit is contained in:
committed by
Steven Palma
parent
6260ffa159
commit
0b957bb4b8
@@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.10")
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user