## One-sentence answer
> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**.

Everything you have seen up to this point (the per-token pad and causality flags) was just metadata; this call turns it into the mask the transformer actually consumes.

---
## What goes in
### Inputs
```python
prefix_pad_masks  # shape [B, L]
prefix_att_masks  # shape [B, L]
```
Where:

* `prefix_pad_masks[b, i] = True`
  → token `i` exists (not padding)

* `prefix_att_masks[b, i] = False`
  → token `i` is **bidirectional**

* `prefix_att_masks[b, i] = True`
  → token `i` is **causal (autoregressive)**

---
## What comes out
```python
att_2d_prefix  # shape [B, L, L]
```
Each entry:

```text
att_2d_prefix[b, i, j] = True
```

means:

> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.”

---
## How it is constructed (conceptually)
For **each batch b**, **each query position i**, **each key position j**:
```python
if not prefix_pad_masks[b, j] or not prefix_pad_masks[b, i]:
    att[b, i, j] = False       # padding neither attends nor is attended to
elif not prefix_att_masks[b, i]:
    att[b, i, j] = True        # bidirectional token → can see all real tokens
else:
    att[b, i, j] = (j <= i)    # causal token → can see only the past + itself
```
That’s it.
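In vectorized form, the same rule is only a few lines of PyTorch. The sketch below follows the rule above exactly; the name `make_att_2d_masks_sketch` is made up here, and it is not necessarily the library's literal implementation:

```python
import torch

def make_att_2d_masks_sketch(pad_masks: torch.Tensor, att_masks: torch.Tensor) -> torch.Tensor:
    """[B, L] pad + causality flags -> [B, L, L] boolean attention mask (rule above, vectorized)."""
    _, L = pad_masks.shape
    idx = torch.arange(L, device=pad_masks.device)
    causal = idx[None, :] <= idx[:, None]              # [L, L]: query i may see key j if j <= i
    # Bidirectional queries (att_masks == False) see every position; causal queries only the past.
    att_2d = (~att_masks)[:, :, None] | causal[None, :, :]
    # Padding can neither attend (query axis) nor be attended to (key axis).
    pad_2d = pad_masks[:, :, None] & pad_masks[:, None, :]
    return att_2d & pad_2d                              # [B, L, L] bool
```

The padding term is symmetric on purpose: a padded position neither attends nor gets attended to, which is what produces the all-✗ row and column in the example below.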
---
## Tiny concrete example (exactly matching your code)
Suppose:
```python
prefix_pad_masks[0] = [T, T, T, T, T, F]
prefix_att_masks[0] = [F, F, F, T, T, T]
```
Tokens:
```
0: IMG
1: IMG
2: LANG
3: SUB0
4: SUB1
5: PAD
```
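Feeding these values into the sketch above reproduces the matrix shown in the next table (an illustrative check, not part of the original code):

```python
pad = torch.tensor([[True, True, True, True, True, False]])
att = torch.tensor([[False, False, False, True, True, True]])
print(make_att_2d_masks_sketch(pad, att)[0].int())
# rows 0-2: 1 1 1 1 1 0   (bidirectional, pad column blocked)
# row  3:   1 1 1 1 0 0   (causal, sees j <= 3)
# row  4:   1 1 1 1 1 0   (causal, sees j <= 4)
# row  5:   0 0 0 0 0 0   (pad row fully blocked)
```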
---
### Resulting `att_2d_prefix[0]`
`✓ = True, ✗ = False`

| Q \ K      | 0 | 1 | 2 | 3 | 4 | 5 |
| ---------- | - | - | - | - | - | - |
| 0 (bi)     | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 1 (bi)     | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 2 (bi)     | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 5 (pad)    | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |

---
## Why this matters for your training code
This line:
```python
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
```

It converts `[B, L, L]` → `[B, 1, L, L]` (adding a broadcastable head dimension) and, depending on the helper, maps `True/False` to an additive `0 / -inf` bias.
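A common way to implement that conversion looks like the sketch below; the name `to_4d_additive_bias` is made up here, and this is an assumption about the helper's behavior rather than its actual body:

```python
import torch

def to_4d_additive_bias(att_2d: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """[B, L, L] bool mask -> [B, 1, L, L] additive bias: 0 where allowed, a large negative where blocked."""
    bias = torch.zeros(att_2d.shape, dtype=dtype, device=att_2d.device)
    bias = bias.masked_fill(~att_2d, torch.finfo(dtype).min)
    return bias[:, None, :, :]  # the singleton dim broadcasts over attention heads
```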
This is **exactly what PaliGemma uses inside self-attention**.

---
## Key implications (VERY important)
### 1️⃣ This mask does **not isolate token groups**
* Bidirectional tokens can attend to **everything**
* Causal tokens only restrict *their own row*

So **flow/action tokens must be blocked separately**.

---
### 2️⃣ This is why your AR subtask prediction works
* Subtask tokens are causal
* Output at position `i` predicts token `i+1`
* Padding is fully ignored
---
### 3️⃣ Inference behavior
When `subtask_tokens = None`:
* `prefix_att_masks` contains only `False`
* `att_2d_prefix` becomes **fully bidirectional**
* No AR behavior remains

Exactly what you want.
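As a quick check of that claim with the sketch above (illustrative values, not from the original code):

```python
pad = torch.tensor([[True, True, True, False]])    # 3 real prefix tokens + 1 pad
att = torch.zeros_like(pad)                        # subtask_tokens is None → every token bidirectional
m = make_att_2d_masks_sketch(pad, att)
assert m[0, :3, :3].all()                          # every real token attends to every real token
assert not m[0, 3].any() and not m[0, :, 3].any()  # padding stays fully blocked
```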
---
## One-sentence takeaway (commit this)
> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.

If you want next, I can:
* inspect the `make_att_2d_masks()` source with you
* show how to block **flow → subtask** attention
* explain how this changes when suffix tokens are added
* help you refactor this into a cleaner “grouped attention” API

You’re now at the point where the model’s behavior should feel *predictable*, not magical.