From 28b86449a2e19550edc27da0ae6ce339cc4febc7 Mon Sep 17 00:00:00 2001 From: pepijn Date: Thu, 21 May 2026 09:51:12 +0000 Subject: [PATCH] fix(pi05): cast attention masks to model dtype Ensure attention masks follow the backbone dtype during bf16 inference to avoid mixed dtype failures. Co-authored-by: Cursor --- src/lerobot/policies/pi05/modeling_pi05.py | 17 ++++++++++++----- src/lerobot/policies/pi052/modeling_pi052.py | 10 +++++----- src/lerobot/policies/pi_gemma.py | 2 ++ 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 56786fbcd..4f3dce1c9 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -617,10 +617,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) return func(*args, **kwargs) - def _prepare_attention_masks_4d(self, att_2d_masks): + def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None): """Helper method to prepare 4D attention masks for transformer.""" att_2d_masks_4d = att_2d_masks[:, None, :, :] - return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + if dtype is not None: + result = result.to(dtype=dtype) + return result def sample_noise(self, shape, device): return torch.normal( @@ -756,7 +759,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` att_2d_masks = make_att_2d_masks(pad_masks, att_masks) position_ids = torch.cumsum(pad_masks, dim=1) - 1 - att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype) def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): (_, suffix_out), _ = self.paligemma_with_expert.forward( @@ -814,7 +817,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d( + prefix_att_2d_masks, dtype=prefix_embs.dtype + ) self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( @@ -884,7 +889,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) + full_att_2d_masks_4d = self._prepare_attention_masks_4d( + full_att_2d_masks, dtype=suffix_embs.dtype + ) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 past_key_values = copy.deepcopy(past_key_values) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index b1146ea1a..e587eca4e 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -630,7 +630,9 @@ class PI052Policy(PI05Policy): att_2d_masks[:, fast_end:, fast_start:fast_end] = False position_ids = torch.cumsum(pad_masks, dim=1) - 1 - att_2d_masks_4d = self.model._prepare_attention_masks_4d(att_2d_masks) + att_2d_masks_4d = self.model._prepare_attention_masks_4d( + att_2d_masks, dtype=prefix_embs.dtype + ) # ---- forward (capture BOTH expert outputs) ------------------ (prefix_out, suffix_out), _ = self.model.paligemma_with_expert.forward( @@ -740,7 +742,7 @@ class PI052Policy(PI05Policy): att_2d = make_att_2d_masks(full_pad, full_att) position_ids = torch.cumsum(full_pad, dim=1) - 1 - att_2d_4d = self.model._prepare_attention_masks_4d(att_2d) + att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=full_embs.dtype) (vlm_out, _), _ = self.model.paligemma_with_expert.forward( attention_mask=att_2d_4d, @@ -864,9 +866,7 @@ class PI052Policy(PI05Policy): for _ in range(max_new_tokens): att_2d = make_att_2d_masks(current_pad, current_att) position_ids = torch.cumsum(current_pad, dim=1) - 1 - att_2d_4d = self.model._prepare_attention_masks_4d(att_2d) - if att_2d_4d.dtype != backbone_dtype: - att_2d_4d = att_2d_4d.to(dtype=backbone_dtype) + att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype) (vlm_out, _), _ = backbone.forward( attention_mask=att_2d_4d, position_ids=position_ids, diff --git a/src/lerobot/policies/pi_gemma.py b/src/lerobot/policies/pi_gemma.py index 05f031d08..ea5a3abcd 100644 --- a/src/lerobot/policies/pi_gemma.py +++ b/src/lerobot/policies/pi_gemma.py @@ -272,6 +272,8 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc] # Convert to bfloat16 if the first layer uses bfloat16 if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.bfloat16) + if causal_mask is not None and torch.is_floating_point(causal_mask): + causal_mask = causal_mask.to(dtype=hidden_states.dtype) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids)