mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
pi052: SDPA attention port + selective AC + bench harness
Replaces the per-layer ``modeling_gemma.eager_attention_forward`` call with ``torch.nn.functional.scaled_dot_product_attention`` in ``compute_layer_complete`` (pi05) and ``_compute_layer_ki`` (pi052). PyTorch SDPA picks the memory-efficient kernel for the block-bidirectional 4D additive mask the dual-expert model uses (FA2 / FA3 reject it because they only accept causal / sliding-window / varlen patterns). The shared ``sdpa_attention_forward`` helper mirrors the eager signature so the call sites are unchanged. Selective AC: removes the redundant outer ``_apply_checkpoint(forward_func, ...)`` wrap in ``PI05Pytorch.forward``. Per-layer checkpointing inside ``PaliGemmaWithExpertModel.forward`` already handles activation recompute; the outer wrap was double-recomputing the whole backbone. +14% steps/sec on its own (job 22161405 vs 22161398, 1xH100). groot: drop ``@strict`` on ``GR00TN15Config`` — newer ``huggingface_hub`` rejects ``@strict`` on non-dataclass ``PretrainedConfig`` subclasses, which was blocking imports of any sibling policy through ``lerobot.policies.factory``. New ``examples/benchmark/bench_pi052_step.py`` (+ slurm sweeps v1..v8) times PI052Policy.forward+backward (optionally with AdamW) on synthetic inputs. Headline numbers on 1xH100 with KI=True, GC=True, L=512, 4.14 B trainable params, AdamW state in bf16: pre-SDPA eager BS=8 610ms 19.5 GiB -> 13.1 samples/s sdpa BS=8 + compile=default 413ms 19.5 GiB -> 19.3 samples/s sdpa BS=16 + compile=default 715ms 37.3 GiB -> 22.4 samples/s sdpa BS=32 + compile=default 1325ms 44.8 GiB -> 24.2 samples/s sdpa BS=40 + compile=default 1665ms 48.6 GiB -> 24.0 samples/s Parity tests in ``tests/policies/pi052/test_pi052_sdpa_attention.py`` cover fp32 / bf16 / GQA / MHA forward + backward — output and grads match the eager path within bf16 tolerance. Also ships ``examples/benchmark/fsdp_pi052.yaml`` (FSDP2 accelerate config wrapping GemmaDecoderLayer + SiglipEncoderLayer) for the follow-up multi-GPU memory sharding work. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Numerical-parity tests for the SDPA attention port.
|
||||
|
||||
``pi05`` / ``pi052`` replaced the per-layer call from
|
||||
``modeling_gemma.eager_attention_forward`` with
|
||||
``sdpa_attention_forward`` (PyTorch SDPA + GQA repeat). The forward
|
||||
output must be bit-equivalent (within bf16 tolerance) on the masks
|
||||
this model actually uses — block-bidirectional with an arbitrary
|
||||
additive bias — otherwise we silently change training behaviour.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from transformers.models.gemma import modeling_gemma # noqa: E402
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import ( # noqa: E402
|
||||
make_att_2d_masks,
|
||||
sdpa_attention_forward,
|
||||
)
|
||||
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE # noqa: E402
|
||||
|
||||
|
||||
def _mock_self_attn(num_kv_groups: int, training: bool = False):
|
||||
"""Bare module surface that both forwards read."""
|
||||
return SimpleNamespace(
|
||||
num_key_value_groups=num_kv_groups,
|
||||
training=training,
|
||||
)
|
||||
|
||||
|
||||
def _build_inputs(
|
||||
bsize: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
seq_len: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int = 0,
|
||||
):
|
||||
g = torch.Generator(device="cpu").manual_seed(seed)
|
||||
q = torch.randn(bsize, num_heads, seq_len, head_dim, dtype=dtype, generator=g)
|
||||
k = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
|
||||
v = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def _block_bidirectional_mask(
|
||||
bsize: int, seq_len: int, block_sizes: list[int], dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""Mimic ``_prepare_attention_masks_4d`` on a block layout that
|
||||
matches ``[images, language, suffix]`` from ``embed_prefix`` +
|
||||
``embed_suffix``: every block bidirectional internally, later
|
||||
blocks visible to earlier ones via the cumulative-block rule.
|
||||
"""
|
||||
assert sum(block_sizes) == seq_len
|
||||
att_marks = []
|
||||
for i, n in enumerate(block_sizes):
|
||||
att_marks += [1 if i > 0 else 0] + [0] * (n - 1)
|
||||
pad = torch.ones(bsize, seq_len, dtype=torch.bool)
|
||||
att = torch.tensor(att_marks, dtype=torch.bool)[None].expand(bsize, seq_len)
|
||||
att_2d = make_att_2d_masks(pad, att)
|
||||
bias = torch.where(
|
||||
att_2d[:, None, :, :],
|
||||
torch.zeros((), dtype=dtype),
|
||||
torch.tensor(OPENPI_ATTENTION_MASK_VALUE, dtype=dtype),
|
||||
)
|
||||
return bias
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_heads,num_kv_heads,head_dim",
|
||||
[
|
||||
(8, 1, 256), # gemma_2b / paligemma config
|
||||
(8, 8, 64), # MHA control (no GQA repeat)
|
||||
],
|
||||
)
|
||||
def test_sdpa_parity_with_eager_block_bidirectional(num_heads, num_kv_heads, head_dim):
|
||||
"""SDPA forward output matches the eager softmax(QK^T)@V on the
|
||||
block-bidirectional mask layout pi05 actually uses."""
|
||||
bsize, seq_len = 2, 13
|
||||
block_sizes = [4, 5, 4] # images, language, suffix-style blocks
|
||||
dtype = torch.float32 # cpu math kernel — keep fp32 for tight tol
|
||||
scaling = head_dim ** -0.5
|
||||
|
||||
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, dtype)
|
||||
mask = _block_bidirectional_mask(bsize, seq_len, block_sizes, dtype)
|
||||
|
||||
module = _mock_self_attn(num_heads // num_kv_heads)
|
||||
|
||||
out_eager, _ = modeling_gemma.eager_attention_forward(
|
||||
module, q, k, v, mask, scaling
|
||||
)
|
||||
out_sdpa, _ = sdpa_attention_forward(
|
||||
module, q, k, v, mask, scaling
|
||||
)
|
||||
assert out_eager.shape == out_sdpa.shape
|
||||
torch.testing.assert_close(out_sdpa, out_eager, atol=1e-5, rtol=1e-4)
|
||||
|
||||
|
||||
def test_sdpa_parity_bf16():
|
||||
"""bf16 path — looser tolerance, must still match eager."""
|
||||
bsize, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 1, 17, 256
|
||||
scaling = head_dim ** -0.5
|
||||
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.bfloat16)
|
||||
mask = _block_bidirectional_mask(bsize, seq_len, [5, 6, 6], torch.bfloat16)
|
||||
module = _mock_self_attn(num_heads // num_kv_heads)
|
||||
|
||||
out_eager, _ = modeling_gemma.eager_attention_forward(
|
||||
module, q, k, v, mask, scaling
|
||||
)
|
||||
out_sdpa, _ = sdpa_attention_forward(
|
||||
module, q, k, v, mask, scaling
|
||||
)
|
||||
torch.testing.assert_close(out_sdpa, out_eager, atol=2e-2, rtol=2e-2)
|
||||
|
||||
|
||||
def test_sdpa_parity_backward():
|
||||
"""Gradients flow through SDPA and match the eager path within
|
||||
bf16 tolerance — critical for any training-side parity claim."""
|
||||
bsize, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 9, 32
|
||||
scaling = head_dim ** -0.5
|
||||
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.float32)
|
||||
q.requires_grad_(True); k.requires_grad_(True); v.requires_grad_(True)
|
||||
mask = _block_bidirectional_mask(bsize, seq_len, [3, 3, 3], torch.float32)
|
||||
module = _mock_self_attn(num_heads // num_kv_heads)
|
||||
|
||||
out_e, _ = modeling_gemma.eager_attention_forward(module, q, k, v, mask, scaling)
|
||||
g_q_e, g_k_e, g_v_e = torch.autograd.grad(out_e.sum(), [q, k, v])
|
||||
|
||||
out_s, _ = sdpa_attention_forward(module, q, k, v, mask, scaling)
|
||||
g_q_s, g_k_s, g_v_s = torch.autograd.grad(out_s.sum(), [q, k, v])
|
||||
|
||||
torch.testing.assert_close(g_q_s, g_q_e, atol=1e-5, rtol=1e-4)
|
||||
torch.testing.assert_close(g_k_s, g_k_e, atol=1e-5, rtol=1e-4)
|
||||
torch.testing.assert_close(g_v_s, g_v_e, atol=1e-5, rtol=1e-4)
|
||||
Reference in New Issue
Block a user