Adhere to Python 3.11 syntax

This commit is contained in:
Pepijn
2025-09-10 20:20:31 +02:00
parent 8178a06b90
commit 8c5fe10d6c
@@ -355,8 +355,9 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
output_attentions: bool | None = False,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: None
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
position_embeddings: (
None | tuple[torch.Tensor, torch.Tensor]
) = None, # necessary, but kept here for BC
adarms_cond: torch.Tensor | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: