Files
lerobot/tests/policies/pi052/test_pi052_sdpa_attention.py
T
pepijn 4913356564 pi052: SDPA attention port + selective AC + bench harness
Replaces the per-layer ``modeling_gemma.eager_attention_forward`` call
with ``torch.nn.functional.scaled_dot_product_attention`` in
``compute_layer_complete`` (pi05) and ``_compute_layer_ki`` (pi052).
PyTorch SDPA picks the memory-efficient kernel for the
block-bidirectional 4D additive mask the dual-expert model uses (FA2 /
FA3 reject it because they only accept causal / sliding-window / varlen
patterns). The shared ``sdpa_attention_forward`` helper mirrors the
eager signature so the call sites are unchanged.

Selective AC: removes the redundant outer ``_apply_checkpoint(forward_func, ...)``
wrap in ``PI05Pytorch.forward``. Per-layer checkpointing inside
``PaliGemmaWithExpertModel.forward`` already handles activation
recompute; the outer wrap was double-recomputing the whole backbone.
+14% steps/sec on its own (job 22161405 vs 22161398, 1xH100).

groot: drop ``@strict`` on ``GR00TN15Config`` — newer ``huggingface_hub``
rejects ``@strict`` on non-dataclass ``PretrainedConfig`` subclasses,
which was blocking imports of any sibling policy through
``lerobot.policies.factory``.

New ``examples/benchmark/bench_pi052_step.py`` (+ slurm sweeps v1..v8)
times PI052Policy.forward+backward (optionally with AdamW) on
synthetic inputs. Headline numbers on 1xH100 with KI=True, GC=True,
L=512, 4.14 B trainable params, AdamW state in bf16:

  pre-SDPA eager BS=8                 610ms   19.5 GiB  ->  13.1 samples/s
  sdpa  BS=8  + compile=default       413ms   19.5 GiB  ->  19.3 samples/s
  sdpa  BS=16 + compile=default       715ms   37.3 GiB  ->  22.4 samples/s
  sdpa  BS=32 + compile=default      1325ms   44.8 GiB  ->  24.2 samples/s
  sdpa  BS=40 + compile=default      1665ms   48.6 GiB  ->  24.0 samples/s

Parity tests in ``tests/policies/pi052/test_pi052_sdpa_attention.py``
cover fp32 / bf16 / GQA / MHA forward + backward — output and grads
match the eager path within tolerance (tight in fp32, looser in bf16).

Also ships ``examples/benchmark/fsdp_pi052.yaml`` (FSDP2 accelerate
config wrapping GemmaDecoderLayer + SiglipEncoderLayer) for the
follow-up multi-GPU memory sharding work.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 21:59:20 +00:00

156 lines
5.8 KiB
Python

#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Numerical-parity tests for the SDPA attention port.
``pi05`` / ``pi052`` replaced the per-layer call from
``modeling_gemma.eager_attention_forward`` with
``sdpa_attention_forward`` (PyTorch SDPA + GQA repeat). The forward
output must match numerically (within fp32 / bf16 tolerance) on the masks
this model actually uses — block-bidirectional with an arbitrary
additive bias — otherwise we silently change training behaviour.
"""
from types import SimpleNamespace
import pytest
import torch
pytest.importorskip("transformers")
from transformers.models.gemma import modeling_gemma # noqa: E402
from lerobot.policies.pi05.modeling_pi05 import ( # noqa: E402
make_att_2d_masks,
sdpa_attention_forward,
)
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE # noqa: E402
def _mock_self_attn(num_kv_groups: int, training: bool = False):
    """Return a minimal stand-in for a Gemma self-attention module.

    Only the attributes that both attention forwards read are provided:
    ``num_key_value_groups`` (query heads per KV head) and ``training``.
    """
    surface = {"num_key_value_groups": num_kv_groups, "training": training}
    return SimpleNamespace(**surface)
def _build_inputs(
    bsize: int,
    num_heads: int,
    num_kv_heads: int,
    seq_len: int,
    head_dim: int,
    dtype: torch.dtype,
    seed: int = 0,
):
    """Draw reproducible random query / key / value tensors on CPU.

    Returns ``(q, k, v)``. ``q`` has shape ``(bsize, num_heads, seq_len, head_dim)``.
    ``k`` and ``v`` have shape ``(bsize, num_kv_heads, seq_len, head_dim)``.
    All three come from a single CPU generator seeded with ``seed`` and are
    drawn in q, k, v order. The same arguments therefore always yield the
    same tensors.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed)
    heads_per_tensor = (num_heads, num_kv_heads, num_kv_heads)  # q, k, v
    return tuple(
        torch.randn((bsize, heads, seq_len, head_dim), dtype=dtype, generator=rng)
        for heads in heads_per_tensor
    )
def _block_bidirectional_mask(
    bsize: int, seq_len: int, block_sizes: list[int], dtype: torch.dtype
) -> torch.Tensor:
    """Build an additive 4D attention bias for a block-structured sequence.

    Mimics ``_prepare_attention_masks_4d`` on a block layout that matches
    ``[images, language, suffix]`` from ``embed_prefix`` + ``embed_suffix``.
    Each entry of ``block_sizes`` is one contiguous block. The first token of
    every block after the first gets an ``att`` mark of 1; all other tokens get
    0. ``make_att_2d_masks`` turns these marks into a boolean visibility mask.
    No padding is applied.

    NOTE(review): this assumes ``make_att_2d_masks`` uses the openpi
    cumulative-sum rule (a key is visible iff ``cumsum[key] <= cumsum[query]``).
    Under that rule, tokens attend bidirectionally within their own block and
    to all *earlier* blocks, but never to later blocks. Confirm against
    ``modeling_pi05``.

    Args:
        bsize: batch size.
        seq_len: total sequence length; must equal ``sum(block_sizes)``.
        block_sizes: lengths of the consecutive blocks; each must be >= 1.
        dtype: dtype of the returned bias.

    Returns:
        A bias holding 0 where attention is allowed and
        ``OPENPI_ATTENTION_MASK_VALUE`` elsewhere. It has a singleton head axis
        at dim 1, presumably giving shape ``(bsize, 1, seq_len, seq_len)``
        (assumes ``make_att_2d_masks`` returns ``(bsize, seq_len, seq_len)``).

    Raises:
        ValueError: if any block size is < 1, or the sizes do not sum to
            ``seq_len``.
    """
    # Explicit errors instead of ``assert`` so the checks survive ``python -O``.
    # A zero-length block would otherwise still emit one mark and only fail
    # later, as an opaque shape error in ``expand``.
    if any(n < 1 for n in block_sizes):
        raise ValueError(f"block sizes must all be >= 1, got {block_sizes}")
    total = sum(block_sizes)
    if total != seq_len:
        raise ValueError(
            f"block sizes {block_sizes} sum to {total}, expected seq_len={seq_len}"
        )
    # A 1 on the first token of each block (except the first) starts a new
    # attention block; every other position is 0.
    att_marks = [
        int(i > 0 and j == 0) for i, n in enumerate(block_sizes) for j in range(n)
    ]
    pad = torch.ones(bsize, seq_len, dtype=torch.bool)
    att = torch.tensor(att_marks, dtype=torch.bool)[None].expand(bsize, seq_len)
    att_2d = make_att_2d_masks(pad, att)
    # Insert a singleton head axis so the bias broadcasts over attention heads.
    return torch.where(
        att_2d[:, None, :, :],
        torch.zeros((), dtype=dtype),
        torch.tensor(OPENPI_ATTENTION_MASK_VALUE, dtype=dtype),
    )
@pytest.mark.parametrize(
    "num_heads,num_kv_heads,head_dim",
    [
        (8, 1, 256),  # gemma_2b / paligemma config
        (8, 8, 64),  # MHA control (no GQA repeat)
    ],
)
def test_sdpa_parity_with_eager_block_bidirectional(num_heads, num_kv_heads, head_dim):
    """SDPA forward matches eager softmax(QK^T)V on the block-bidirectional mask.

    The three blocks (4 / 5 / 4 tokens) stand in for the images, language and
    suffix segments. Inputs stay in fp32, so the CPU kernels can be held to a
    tight tolerance.
    """
    batch, length = 2, 13
    dtype = torch.float32
    q, k, v = _build_inputs(batch, num_heads, num_kv_heads, length, head_dim, dtype)
    bias = _block_bidirectional_mask(batch, length, [4, 5, 4], dtype)
    attn = _mock_self_attn(num_heads // num_kv_heads)
    # Both forwards share the eager signature, so feed them identical arguments.
    call_args = (attn, q, k, v, bias, head_dim ** -0.5)
    reference, _ = modeling_gemma.eager_attention_forward(*call_args)
    candidate, _ = sdpa_attention_forward(*call_args)
    assert candidate.shape == reference.shape
    torch.testing.assert_close(candidate, reference, atol=1e-5, rtol=1e-4)
def test_sdpa_parity_bf16():
    """bf16 inputs: SDPA must still agree with eager, at a looser tolerance.

    Uses 8 query heads sharing a single KV head (head_dim 256) over three
    blocks of 5 / 6 / 6 tokens.
    """
    dtype = torch.bfloat16
    head_dim = 256
    batch, n_heads, n_kv_heads, length = 2, 8, 1, 17
    q, k, v = _build_inputs(batch, n_heads, n_kv_heads, length, head_dim, dtype)
    bias = _block_bidirectional_mask(batch, length, [5, 6, 6], dtype)
    attn = _mock_self_attn(n_heads // n_kv_heads)
    call_args = (attn, q, k, v, bias, head_dim ** -0.5)
    reference, _ = modeling_gemma.eager_attention_forward(*call_args)
    candidate, _ = sdpa_attention_forward(*call_args)
    torch.testing.assert_close(candidate, reference, atol=2e-2, rtol=2e-2)
def test_sdpa_parity_backward():
    """Gradients through SDPA match the eager path within fp32 tolerance.

    This is critical for any training-side parity claim. The layout is GQA
    (4 query heads sharing 2 KV heads), so the key/value gradients must also
    be reduced correctly across the repeated heads.

    The test backpropagates a seeded random upstream gradient rather than
    ``out.sum()``. A sum loss is invariant to any permutation of the output
    elements, so it cannot detect a same-shape layout mismatch between the
    two forwards; a random cotangent weights every element differently.
    """
    bsize, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 9, 32
    dtype = torch.float32
    scaling = head_dim ** -0.5
    inputs = [
        t.requires_grad_(True)
        for t in _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, dtype)
    ]
    mask = _block_bidirectional_mask(bsize, seq_len, [3, 3, 3], dtype)
    module = _mock_self_attn(num_heads // num_kv_heads)

    out_e, _ = modeling_gemma.eager_attention_forward(module, *inputs, mask, scaling)
    out_s, _ = sdpa_attention_forward(module, *inputs, mask, scaling)
    assert out_s.shape == out_e.shape

    # Separate seed from the one _build_inputs uses, so the cotangent is not
    # correlated with q/k/v.
    gen = torch.Generator(device="cpu").manual_seed(1)
    grad_out = torch.randn(out_e.shape, dtype=dtype, generator=gen)
    grads_e = torch.autograd.grad(out_e, inputs, grad_outputs=grad_out)
    grads_s = torch.autograd.grad(out_s, inputs, grad_outputs=grad_out)

    for name, g_s, g_e in zip(("q", "k", "v"), grads_s, grads_e):
        torch.testing.assert_close(
            g_s,
            g_e,
            atol=1e-5,
            rtol=1e-4,
            msg=lambda m, name=name: f"grad w.r.t. {name} differs: {m}",
        )