## One-sentence answer

> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**. Everything else you've seen so far was just metadata.

---

## What goes in

### Inputs

```python
prefix_pad_masks  # shape [B, L]
prefix_att_masks  # shape [B, L]
```

Where:

* `prefix_pad_masks[b, i] = True` → token `i` exists (not padding)
* `prefix_att_masks[b, i] = False` → token `i` is **bidirectional**
* `prefix_att_masks[b, i] = True` → token `i` is **causal (autoregressive)**

---

## What comes out

```python
att_2d_prefix  # shape [B, L, L]
```

Each entry:

```text
att_2d_prefix[b, i, j] = True
```

means:

> "In batch `b`, **token i (query)** is allowed to attend to **token j (key)**."

---

## How it is constructed (conceptually)

For **each batch b**, **each query position i**, **each key position j**:

```python
if not prefix_pad_masks[b, i] or not prefix_pad_masks[b, j]:
    att[b, i, j] = False       # padding never attends and is never attended to
elif not prefix_att_masks[b, i]:
    att[b, i, j] = True        # bidirectional token → can see all real tokens
else:
    att[b, i, j] = (j <= i)    # causal token → can see only past + itself
```

That's it. (A runnable sketch of this construction appears at the end of this answer.)

---

## Tiny concrete example (exactly matching your code)

Suppose:

```python
prefix_pad_masks[0] = [T, T, T, T, T, F]
prefix_att_masks[0] = [F, F, F, T, T, T]
```

Tokens:

```
0: IMG
1: IMG
2: LANG
3: SUB0
4: SUB1
5: PAD
```

---

### Resulting `att_2d_prefix[0]`

`✓ = True, ✗ = False`

| Q \ K      | 0 | 1 | 2 | 3 | 4 | 5 |
| ---------- | - | - | - | - | - | - |
| 0 (bi)     | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 1 (bi)     | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 2 (bi)     | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 5 (pad)    | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |

---

## Why this matters for your training code

This line:

```python
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
```

converts `[B, L, L] → [B, 1, L, L]` and possibly flips `True/False` to `0/-inf` (a sketch of that conversion is also at the end). This is **exactly what PaliGemma uses inside self-attention**.

---

## Key implications (VERY important)

### 1️⃣ This mask does **not isolate token groups**

* Bidirectional tokens can attend to **everything**
* Causal tokens only restrict *their own row*

So **flow/action tokens must be blocked separately**.

---

### 2️⃣ This is why your AR subtask prediction works

* Subtask tokens are causal
* Output at position `i` predicts token `i+1`
* Padding is fully ignored

---

### 3️⃣ Inference behavior

When `subtask_tokens = None`:

* `prefix_att_masks` contains only `False`
* `att_2d_prefix` becomes **fully bidirectional**
* No AR behavior remains

Exactly what you want.

---

## One-sentence takeaway (commit this)

> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.

If you want next, I can:

* inspect `make_att_2d_masks()` source with you
* show how to block **flow → subtask** attention
* explain how this changes when suffix tokens are added
* help you refactor this into a cleaner "grouped attention" API

You're now at the point where the model's behavior should feel *predictable*, not magical.
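
---

## Bonus: a minimal sketch of the construction

Since we haven't opened the actual `make_att_2d_masks()` source yet, here is a minimal PyTorch sketch that reproduces the semantics described above, including the example table. Treat it as an illustration, not the real implementation; the actual code may use a different formulation (e.g., a cumsum over block boundaries), and the name `make_att_2d_masks_sketch` is made up for this note.

```python
import torch


def make_att_2d_masks_sketch(pad_masks: torch.Tensor, att_masks: torch.Tensor) -> torch.Tensor:
    """Combine padding and per-token causality into a [B, L, L] bool attention mask.

    pad_masks: [B, L] bool, True where the token is real (not padding).
    att_masks: [B, L] bool, True where the token is causal, False where bidirectional.
    Returns:   [B, L, L] bool, True where query i may attend to key j.
    """
    B, L = pad_masks.shape
    idx = torch.arange(L, device=pad_masks.device)

    # Causal constraint: a causal query at position i may only see keys j <= i.
    causal = (idx[None, :] <= idx[:, None]).expand(B, L, L)

    # Bidirectional queries ignore the causal constraint entirely.
    att_2d = torch.where(att_masks[:, :, None], causal, torch.ones_like(causal))

    # Padding never attends and is never attended to (query and key side).
    pad_2d = pad_masks[:, :, None] & pad_masks[:, None, :]
    return att_2d & pad_2d
```

Running it on the example input reproduces the table row by row:

```python
pad = torch.tensor([[1, 1, 1, 1, 1, 0]], dtype=torch.bool)  # IMG IMG LANG SUB0 SUB1 PAD
att = torch.tensor([[0, 0, 0, 1, 1, 1]], dtype=torch.bool)  # bi  bi  bi   causal x3
print(make_att_2d_masks_sketch(pad, att)[0].int())
# Rows 0-2 see keys 0-4, row 3 stops at key 3, row 4 at key 4, row 5 (pad) is all zeros.
```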
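The 4D step in `_prepare_attention_masks_4d` was only inferred above ("possibly flips `True/False` to `0/-inf`"), so here is one plausible version of that conversion, again as a sketch rather than your actual helper:

```python
def prepare_attention_masks_4d_sketch(
    att_2d: torch.Tensor, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """[B, L, L] bool -> [B, 1, L, L] additive mask: 0 where attention is allowed,
    a large negative value (effectively -inf after softmax) where it is blocked."""
    att_4d = att_2d[:, None, :, :]  # insert the broadcastable head dimension
    big_neg = torch.finfo(dtype).min
    return torch.where(att_4d, torch.tensor(0.0, dtype=dtype), torch.tensor(big_neg, dtype=dtype))
```

Whether the real helper uses `torch.finfo(dtype).min` or a literal `-inf` depends on the attention implementation it feeds; either way, blocked positions end up with zero weight after the softmax.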