One-sentence answer

make_att_2d_masks(prefix_pad_masks, prefix_att_masks) builds the actual 2D attention mask [B, L, L] that tells the transformer which token positions may attend to which others, combining padding and causality.

Everything else you've seen so far was just metadata.


What goes in

Inputs

prefix_pad_masks   # shape [B, L]
prefix_att_masks   # shape [B, L]

Where:

  • prefix_pad_masks[b, i] = True → token i exists (not padding)

  • prefix_att_masks[b, i] = False → token i is bidirectional

  • prefix_att_masks[b, i] = True → token i is causal (autoregressive)


What comes out

att_2d_prefix  # shape [B, L, L]

Each entry:

att_2d_prefix[b, i, j] = True

means:

“In batch b, token i (query) is allowed to attend to token j (key).”


How it is constructed (conceptually)

For each batch b, each query position i, each key position j:

if not prefix_pad_masks[b, j]:
    att[b, i, j] = False           # cannot attend to padding
elif not prefix_att_masks[b, i]:
    att[b, i, j] = True            # bidirectional token → can see all real tokens
else:
    att[b, i, j] = (j <= i)        # causal token → can see only past + itself

That's it.
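
Here is the same rule as a minimal runnable PyTorch sketch. This is a deliberately naive triple loop written for this note, not the actual make_att_2d_masks source, and the name make_att_2d_masks_reference is made up for illustration:

import torch

def make_att_2d_masks_reference(pad_masks, att_masks):
    # pad_masks: bool [B, L], True where the token is real (not padding)
    # att_masks: bool [B, L], True where the token is causal, False where bidirectional
    # returns:   bool [B, L, L], entry [b, i, j] = True means query i may attend to key j
    B, L = pad_masks.shape
    att = torch.zeros(B, L, L, dtype=torch.bool)
    for b in range(B):
        for i in range(L):
            for j in range(L):
                if not pad_masks[b, j]:
                    att[b, i, j] = False        # never attend to padding keys
                elif not att_masks[b, i]:
                    att[b, i, j] = True         # bidirectional query: sees every real token
                else:
                    att[b, i, j] = j <= i       # causal query: sees only past + itself
    return att

The real implementation will be vectorized rather than a triple loop, but the per-element rule is the one shown above.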


Tiny concrete example (exactly matching your code)

Suppose:

prefix_pad_masks[0] = [T, T, T, T, T, F]
prefix_att_masks[0] = [F, F, F, T, T, T]

Tokens:

0: IMG
1: IMG
2: LANG
3: SUB0
4: SUB1
5: PAD

Resulting att_2d_prefix[0]

✓ = True, ✗ = False

Q \ K        0   1   2   3   4   5
0 (bi)       ✓   ✓   ✓   ✓   ✓   ✗
1 (bi)       ✓   ✓   ✓   ✓   ✓   ✗
2 (bi)       ✓   ✓   ✓   ✓   ✓   ✗
3 (causal)   ✓   ✓   ✓   ✓   ✗   ✗
4 (causal)   ✓   ✓   ✓   ✓   ✓   ✗
5 (pad)      ✓   ✓   ✓   ✓   ✓   ✗   (row irrelevant: query 5 is padding, so its output is ignored)
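
To reproduce this table yourself, you can feed the same two masks into the reference sketch from the previous section:

prefix_pad_masks = torch.tensor([[True, True, True, True, True, False]])
prefix_att_masks = torch.tensor([[False, False, False, True, True, True]])

att_2d_prefix = make_att_2d_masks_reference(prefix_pad_masks, prefix_att_masks)
print(att_2d_prefix[0].int())   # 1 = ✓, 0 = ✗, row by row as in the table above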

Why this matters for your training code

This line:

att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)

Converts [B, L, L] → [B, 1, L, L] and possibly flips True/False to 0/-inf.

This is exactly the mask format PaliGemma uses inside self-attention.
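
As a rough sketch of what that conversion usually looks like (the real _prepare_attention_masks_4d may differ in details such as dtype handling and the exact fill value):

def prepare_attention_masks_4d(att_2d, dtype=torch.float32):
    # [B, L, L] bool -> [B, 1, L, L] additive mask:
    # 0 where attention is allowed, a huge negative number (effectively -inf) where blocked
    att_4d = att_2d[:, None, :, :]
    return (~att_4d).to(dtype) * torch.finfo(dtype).min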


Key implications (VERY important)

1️⃣ This mask does not isolate token groups

  • Bidirectional tokens can attend to everything
  • Causal tokens only restrict their own row

So flow/action tokens must be blocked separately.
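
If you do want to block a group, one way (sketched here with hypothetical names; att_2d_full, flow_positions and subtask_positions are not from your code) is to clear the relevant block of the 2D mask after it is built:

flow_positions = slice(48, 98)      # hypothetical: where flow/action tokens sit in the full sequence
subtask_positions = slice(3, 5)     # hypothetical: where SUB0/SUB1 sit
att_2d_full[:, flow_positions, subtask_positions] = False   # flow queries can no longer see subtask keys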


2️⃣ This is why your AR subtask prediction works

  • Subtask tokens are causal
  • Output at position i predicts token i+1
  • Padding is fully ignored
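
A hedged sketch of that shifted next-token loss (subtask_logits, subtask_targets and the -100 padding convention are illustrative assumptions, not your actual variable names):

import torch.nn.functional as F

# subtask_logits:  [B, S, V] logits over the subtask positions
# subtask_targets: [B, S] token ids, with padding positions set to -100
shift_logits = subtask_logits[:, :-1, :]    # output at position i ...
shift_targets = subtask_targets[:, 1:]      # ... is scored against token i+1
loss = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_targets.reshape(-1),
    ignore_index=-100,                      # padding is fully ignored
)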

3️⃣ Inference behavior

When subtask_tokens = None:

  • prefix_att_masks contains only False
  • att_2d_prefix becomes fully bidirectional
  • No AR behavior remains

Exactly what you want.
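
You can sanity-check that with the reference sketch from earlier:

pad = torch.ones(1, 4, dtype=torch.bool)    # four real tokens, no padding
att = torch.zeros(1, 4, dtype=torch.bool)   # all False → no causal tokens
print(make_att_2d_masks_reference(pad, att).all())   # tensor(True): fully bidirectional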


One-sentence takeaway (commit this)

make_att_2d_masks fuses padding and causality into a concrete [B, L, L] attention matrix that the transformer actually uses.

If you want next, I can:

  • inspect make_att_2d_masks() source with you
  • show how to block flow → subtask attention
  • explain how this changes when suffix tokens are added
  • help you refactor this into a cleaner “grouped attention” API

You're now at the point where the model's behavior should feel predictable, not magical.