diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py
index c1b277abe..6ddb6c337 100644
--- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py
+++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py
@@ -355,8 +355,9 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
         output_attentions: bool | None = False,
         use_cache: bool | None = False,
         cache_position: torch.LongTensor | None = None,
-        position_embeddings: None
-        | (tuple[torch.Tensor, torch.Tensor]) = None,  # necessary, but kept here for BC
+        position_embeddings: (
+            None | tuple[torch.Tensor, torch.Tensor]
+        ) = None,  # necessary, but kept here for BC
         adarms_cond: torch.Tensor | None = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: