One-sentence answer
`make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the actual 2D attention mask `[B, L, L]` that tells the transformer which token positions may attend to which others, combining padding and causality.
Everything else you’ve seen so far was just metadata.
What goes in
Inputs
prefix_pad_masks # shape [B, L]
prefix_att_masks # shape [B, L]
Where:
- `prefix_pad_masks[b, i] = True` → token `i` exists (not padding)
- `prefix_att_masks[b, i] = False` → token `i` is bidirectional
- `prefix_att_masks[b, i] = True` → token `i` is causal (autoregressive)
What comes out
att_2d_prefix # shape [B, L, L]
Each entry:
att_2d_prefix[b, i, j] = True
means: “In batch `b`, token `i` (query) is allowed to attend to token `j` (key).”
How it is constructed (conceptually)
For each batch b, each query position i, each key position j:
if not prefix_pad_masks[b, i] or not prefix_pad_masks[b, j]:
    att[b, i, j] = False      # padding never attends and is never attended to
else if not prefix_att_masks[b, i]:
    att[b, i, j] = True       # bidirectional token → can see all real tokens
else:
    att[b, i, j] = (j <= i)   # causal token → can see only past + itself
That’s it.
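A minimal vectorized sketch of that rule in PyTorch, for intuition only (the helper name `build_att_2d` is mine, not the function in your codebase, and the real implementation may construct the same mask differently):

```python
import torch

def build_att_2d(pad_masks: torch.Tensor, att_masks: torch.Tensor) -> torch.Tensor:
    """pad_masks: bool [B, L] (True = real token); att_masks: bool [B, L] (True = causal).
    Returns bool [B, L, L] where [b, i, j] = query i may attend to key j."""
    L = pad_masks.shape[1]
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=pad_masks.device))
    # bidirectional queries see every position; causal queries see only j <= i
    att = (~att_masks)[:, :, None] | causal[None]
    # padded queries and padded keys are excluded entirely
    att = att & pad_masks[:, :, None] & pad_masks[:, None, :]
    return att
```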
Tiny concrete example (exactly matching your code)
Suppose:
prefix_pad_masks[0] = [T, T, T, T, T, F]
prefix_att_masks[0] = [F, F, F, T, T, T]
Tokens:
0: IMG
1: IMG
2: LANG
3: SUB0
4: SUB1
5: PAD
Resulting att_2d_prefix[0]
✓ = True, ✗ = False
| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|---|
| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
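As a sanity check, feeding the example masks into the `build_att_2d` sketch above reproduces this table (again, illustrative only):

```python
pad = torch.tensor([[True, True, True, True, True, False]])
att = torch.tensor([[False, False, False, True, True, True]])
print(build_att_2d(pad, att)[0].int())
# tensor([[1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.int32)
```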
Why this matters for your training code
This line:
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
converts `[B, L, L]` → `[B, 1, L, L]` (adding a head-broadcast dimension) and, depending on the implementation, turns the boolean mask into an additive one (True → 0, False → −inf).
This is exactly the mask format PaliGemma uses inside self-attention.
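A sketch of what such a conversion typically looks like, assuming an additive float mask is the target format; your actual `_prepare_attention_masks_4d` may differ in details:

```python
def to_4d_additive(att_2d: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """[B, L, L] bool -> [B, 1, L, L] float: 0 where attention is allowed, large negative where blocked."""
    att_4d = att_2d[:, None, :, :]                                    # head-broadcast dimension
    additive = torch.zeros(att_4d.shape, dtype=dtype, device=att_2d.device)
    return additive.masked_fill(~att_4d, torch.finfo(dtype).min)
```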
Key implications (VERY important)
1️⃣ This mask does not isolate token groups
- Bidirectional tokens can attend to everything
- Causal tokens only restrict their own row
So flow/action tokens must be blocked separately.
2️⃣ This is why your AR subtask prediction works
- Subtask tokens are causal
- The output at position `i` predicts token `i+1`
- Padding is fully ignored
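As a rough illustration of that shift-by-one objective (names like `subtask_logits`, `subtask_ids`, and `subtask_pad_mask` are placeholders, not your actual variables):

```python
import torch.nn.functional as F

def subtask_ar_loss(subtask_logits, subtask_ids, subtask_pad_mask):
    """subtask_logits: [B, S, V]; subtask_ids: [B, S]; subtask_pad_mask: bool [B, S] (True = real token)."""
    logits = subtask_logits[:, :-1]              # position i predicts token i+1
    targets = subtask_ids[:, 1:]
    valid = subtask_pad_mask[:, 1:].float()      # padded targets contribute nothing
    per_token = F.cross_entropy(logits.flatten(0, 1), targets.flatten(), reduction="none")
    return (per_token * valid.flatten()).sum() / valid.sum().clamp(min=1.0)
```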
3️⃣ Inference behavior
When `subtask_tokens = None`:
- `prefix_att_masks` contains only `False`
- `att_2d_prefix` becomes fully bidirectional
- No AR behavior remains
Exactly what you want.
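Using the `build_att_2d` sketch again, you can see the rule collapse to plain padding masking once no token is marked causal:

```python
pad = torch.tensor([[True, True, True, True, True, False]])
no_causal = torch.zeros_like(pad)                 # prefix_att_masks: all False
full_bi = build_att_2d(pad, no_causal)
# every real token attends to every real token; padding stays excluded on both sides
assert torch.equal(full_bi, pad[:, :, None] & pad[:, None, :])
```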
One-sentence takeaway (commit this)
`make_att_2d_masks` fuses padding and causality into a concrete `[B, L, L]` attention matrix that the transformer actually uses.
If you want next, I can:
- inspect the `make_att_2d_masks()` source with you
- show how to block flow → subtask attention
- explain how this changes when suffix tokens are added
- help you refactor this into a cleaner “grouped attention” API
You’re now at the point where the model’s behavior should feel predictable, not magical.