mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
fix(pi05): cast attention masks to model dtype
Ensure attention masks follow the backbone dtype during bf16 inference to avoid mixed-dtype failures.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -617,10 +617,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
|
||||
"""Helper method to prepare 4D attention masks for transformer."""
|
||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
||||
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
if dtype is not None:
|
||||
result = result.to(dtype=dtype)
|
||||
return result
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
@@ -756,7 +759,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
@@ -814,7 +817,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
|
||||
prefix_att_2d_masks, dtype=prefix_embs.dtype
|
||||
)
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
@@ -884,7 +889,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(
|
||||
full_att_2d_masks, dtype=suffix_embs.dtype
|
||||
)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
|
||||
@@ -630,7 +630,9 @@ class PI052Policy(PI05Policy):
|
||||
att_2d_masks[:, fast_end:, fast_start:fast_end] = False
|
||||
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
att_2d_masks_4d = self.model._prepare_attention_masks_4d(att_2d_masks)
|
||||
att_2d_masks_4d = self.model._prepare_attention_masks_4d(
|
||||
att_2d_masks, dtype=prefix_embs.dtype
|
||||
)
|
||||
|
||||
# ---- forward (capture BOTH expert outputs) ------------------
|
||||
(prefix_out, suffix_out), _ = self.model.paligemma_with_expert.forward(
|
||||
@@ -740,7 +742,7 @@ class PI052Policy(PI05Policy):
|
||||
|
||||
att_2d = make_att_2d_masks(full_pad, full_att)
|
||||
position_ids = torch.cumsum(full_pad, dim=1) - 1
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=full_embs.dtype)
|
||||
|
||||
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_4d,
|
||||
@@ -864,9 +866,7 @@ class PI052Policy(PI05Policy):
|
||||
for _ in range(max_new_tokens):
|
||||
att_2d = make_att_2d_masks(current_pad, current_att)
|
||||
position_ids = torch.cumsum(current_pad, dim=1) - 1
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
|
||||
if att_2d_4d.dtype != backbone_dtype:
|
||||
att_2d_4d = att_2d_4d.to(dtype=backbone_dtype)
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype)
|
||||
(vlm_out, _), _ = backbone.forward(
|
||||
attention_mask=att_2d_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -272,6 +272,8 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
if causal_mask is not None and torch.is_floating_point(causal_mask):
|
||||
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
Reference in New Issue
Block a user