mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
pi052: SDPA attention port + selective AC + bench harness
Replaces the per-layer ``modeling_gemma.eager_attention_forward`` call with ``torch.nn.functional.scaled_dot_product_attention`` in ``compute_layer_complete`` (pi05) and ``_compute_layer_ki`` (pi052). PyTorch SDPA picks the memory-efficient kernel for the block-bidirectional 4D additive mask the dual-expert model uses (FA2 / FA3 reject it because they only accept causal / sliding-window / varlen patterns). The shared ``sdpa_attention_forward`` helper mirrors the eager signature so the call sites are unchanged.

Selective AC: removes the redundant outer ``_apply_checkpoint(forward_func, ...)`` wrap in ``PI05Pytorch.forward``. Per-layer checkpointing inside ``PaliGemmaWithExpertModel.forward`` already handles activation recompute; the outer wrap was double-recomputing the whole backbone. +14% steps/sec on its own (job 22161405 vs 22161398, 1xH100).

groot: drop ``@strict`` on ``GR00TN15Config`` — newer ``huggingface_hub`` rejects ``@strict`` on non-dataclass ``PretrainedConfig`` subclasses, which was blocking imports of any sibling policy through ``lerobot.policies.factory``.

New ``examples/benchmark/bench_pi052_step.py`` (+ slurm sweeps v1..v8) times PI052Policy.forward+backward (optionally with AdamW) on synthetic inputs. Headline numbers on 1xH100 with KI=True, GC=True, L=512, 4.14 B trainable params, AdamW state in bf16:

    pre-SDPA eager BS=8             610ms   19.5 GiB -> 13.1 samples/s
    sdpa BS=8  + compile=default    413ms   19.5 GiB -> 19.3 samples/s
    sdpa BS=16 + compile=default    715ms   37.3 GiB -> 22.4 samples/s
    sdpa BS=32 + compile=default   1325ms   44.8 GiB -> 24.2 samples/s
    sdpa BS=40 + compile=default   1665ms   48.6 GiB -> 24.0 samples/s

Parity tests in ``tests/policies/pi052/test_pi052_sdpa_attention.py`` cover fp32 / bf16 / GQA / MHA forward + backward — output and grads match the eager path within bf16 tolerance.
Also ships ``examples/benchmark/fsdp_pi052.yaml`` (FSDP2 accelerate config wrapping GemmaDecoderLayer + SiglipEncoderLayer) for the follow-up multi-GPU memory sharding work. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,74 @@
|
|||||||
|
#!/bin/bash
#SBATCH --job-name=bench-pi052-kernels
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=01:30:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_kernels_%j.out

# HF kernels exploration via Liger's apply_liger_kernel_to_paligemma.
# Baseline (SDPA, no kernels) vs. per-subkernel ablations vs. all-on.
# Same harness as bench_pi052_step.py — only the --kernels flag varies
# across runs so any delta is attributable to the patched op(s).
#
# Subkernels exercised: rope, rms_norm, geglu, layer_norm.
# Skipped: cross_entropy / fused_linear_cross_entropy — pi052 calls
# F.cross_entropy directly and bypasses PaliGemma's forward, so those
# patches wouldn't fire without model-code changes (separate PR).

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

# /fsx triton cache is shared across nodes with different glibc versions
# — kernels built on one node trip GLIBC_2.34-not-found on another. Use
# a node-local cache per job to side-step that.
# Fall back to the shell PID when SLURM_JOB_ID is unset (e.g. a manual run)
# so `set -u` does not abort the script before the first benchmark.
cache_tag="${SLURM_JOB_ID:-manual_$$}"
export TRITON_CACHE_DIR="/tmp/triton_${cache_tag}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${cache_tag}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
# The caches are per-job by construction; remove them on exit so /tmp on the
# compute node does not accumulate one directory pair per sweep.
trap 'rm -rf "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"' EXIT

echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
# sed consumes the whole stream, so ldd cannot receive SIGPIPE (which would
# fail the pipeline under `set -o pipefail`) the way it can with `head -1`.
ldd --version | sed -n '1p'

# Liger isn't in our standard env yet — install on the compute node so
# the slurm log captures the exact version that produced the numbers.
python -m pip install -q --upgrade 'liger-kernel'
python - <<'PY' || true
from importlib.metadata import version, PackageNotFoundError
try:
    print("liger-kernel", version("liger-kernel"))
except PackageNotFoundError:
    print("liger-kernel: not importable")
import liger_kernel.transformers as t
print("apply_liger_kernel_to_paligemma:", hasattr(t, "apply_liger_kernel_to_paligemma"))
PY

# Run one benchmark configuration. A failing configuration (e.g. an OOM at a
# large batch size) must not stop the sweep, but it is reported with its exit
# code so the log shows which rows are missing and why.
run() {
    echo
    echo "--- $* ---"
    local rc=0
    python examples/benchmark/bench_pi052_step.py "$@" || rc=$?
    if (( rc != 0 )); then
        echo "!!! FAILED (exit ${rc}): $*"
    fi
}

# -- Baseline (no kernels) at the BS we actually train at. --
run --attn sdpa --batch-size 8 --kernels none
run --attn sdpa --batch-size 16 --kernels none

# -- Per-subkernel ablations at BS=16 to isolate each contributor. --
run --attn sdpa --batch-size 16 --kernels rms_norm
run --attn sdpa --batch-size 16 --kernels geglu
run --attn sdpa --batch-size 16 --kernels layer_norm
run --attn sdpa --batch-size 16 --kernels rope

# -- All-on, both BS to compare against the matched baselines above. --
run --attn sdpa --batch-size 8 --kernels all
run --attn sdpa --batch-size 16 --kernels all

# -- Headroom check: does kernels-all let BS=24 fit (baseline OOMs near here)? --
run --attn sdpa --batch-size 24 --kernels none
run --attn sdpa --batch-size 24 --kernels all
|
||||||
@@ -0,0 +1,338 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Benchmark ``PI052Policy.forward + backward`` on a single GPU.
|
||||||
|
|
||||||
|
Compares the new SDPA attention path against the eager baseline by
|
||||||
|
monkeypatching ``sdpa_attention_forward`` before the first model
|
||||||
|
forward — so both runs share identical Q/K/V plumbing and only the
|
||||||
|
attention kernel differs. Reports steps/sec and peak GPU memory.
|
||||||
|
|
||||||
|
SLURM-only:
|
||||||
|
|
||||||
|
sbatch examples/benchmark/bench_pi052_step.slurm
|
||||||
|
|
||||||
|
Or one-off:
|
||||||
|
|
||||||
|
srun --partition=hopper-prod --qos=high --gpus=1 --time=15 \\
|
||||||
|
python examples/benchmark/bench_pi052_step.py --attn sdpa --batch-size 8
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_patch_eager() -> None:
    """Route attention through the stock Gemma eager kernel.

    Rebinds ``lerobot.policies.pi05.modeling_pi05.sdpa_attention_forward`` to
    ``transformers``' ``modeling_gemma.eager_attention_forward``.

    Call this BEFORE ``PI052Policy`` is instantiated. Per the original author's
    note, the layer-compute functions look the symbol up on the module at call
    time, which is what lets a single rebind cover both the pi05 and the
    pi052 KI paths.
    """
    from transformers.models.gemma import modeling_gemma as gemma_impl

    from lerobot.policies.pi05 import modeling_pi05 as pi05_impl

    eager_kernel = gemma_impl.eager_attention_forward
    pi05_impl.sdpa_attention_forward = eager_kernel
|
||||||
|
|
||||||
|
|
||||||
|
_LIGER_SUBKERNELS = ("rope", "rms_norm", "geglu", "layer_norm")


def _maybe_patch_liger(spec: str) -> dict[str, bool]:
    """Globally patch PaliGemma/Gemma/Siglip modules with Liger Triton kernels.

    Must be called BEFORE PI052Policy is instantiated — Liger replaces
    classes inside ``transformers.models.{gemma,gemma2,siglip,paligemma}``,
    so any model built after the call picks up the fused forwards.

    ``cross_entropy`` and ``fused_linear_cross_entropy`` are intentionally
    always passed as ``False`` — pi052's losses use ``F.cross_entropy``
    directly (not ``nn.CrossEntropyLoss``) and never traverse
    ``PaliGemmaForConditionalGeneration.forward``, so neither patch would
    fire without invasive model-code changes.

    Args:
        spec: Comma-separated subset of ``_LIGER_SUBKERNELS``. ``"all"``
            enables every sub-kernel; ``"none"``, an empty string, or a
            string holding only separators/whitespace enables nothing.
            Whitespace around each token is ignored.

    Returns:
        A mapping from every name in ``_LIGER_SUBKERNELS`` to whether it was
        enabled.

    Raises:
        SystemExit: If ``spec`` names a sub-kernel that is not recognised.
    """
    enabled = dict.fromkeys(_LIGER_SUBKERNELS, False)

    # Normalise first, so " none " or "," are treated as "nothing requested"
    # and do not fall through to the Liger import below.
    tokens = [tok.strip() for tok in (spec or "").split(",") if tok.strip()]
    if not tokens or tokens == ["none"]:
        return enabled

    # Validate before touching Liger so a typo fails fast even when the
    # package is not installed.
    for tok in tokens:
        if tok != "all" and tok not in enabled:
            raise SystemExit(f"Unknown liger subkernel: {tok!r}. Choose from {_LIGER_SUBKERNELS} or 'all'.")

    if "all" in tokens:
        enabled = dict.fromkeys(_LIGER_SUBKERNELS, True)
    else:
        for tok in tokens:
            enabled[tok] = True

    # Imported lazily: Liger is an optional dependency that is only needed
    # when at least one sub-kernel was actually requested.
    from liger_kernel.transformers import apply_liger_kernel_to_paligemma

    apply_liger_kernel_to_paligemma(
        rope=enabled["rope"],
        rms_norm=enabled["rms_norm"],
        geglu=enabled["geglu"],
        layer_norm=enabled["layer_norm"],
        cross_entropy=False,
        fused_linear_cross_entropy=False,
    )
    return enabled
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_patch_flex() -> None:
    """Replace ``sdpa_attention_forward`` with a FlexAttention-based forward.

    Experimental. The replacement turns the additive attention mask into a
    ``score_mod`` closure on every call and hands it to a ``torch.compile``'d
    ``flex_attention``.

    Known issue on torch 2.7.1 (per the original author): dynamo fails with
    ``FlexAttentionHigherOrderVariable() has no type`` when ``score_mod``
    closes over a per-call bias tensor. A proper port would plumb
    ``create_block_mask(mask_mod, ...)`` at the ``PI05Pytorch.forward`` level
    and pass a BlockMask object down to the layer compute instead of a
    per-call closure. That is left as future work; this stub is kept for
    benchmark experimentation only.
    """
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    from lerobot.policies.pi05 import modeling_pi05

    flex_kernel = torch.compile(flex_attention, dynamic=True)

    def flex_forward(module, query, key, value, attention_mask, scaling, dropout=0.0):
        # ``dropout`` is accepted but not used by this implementation.
        kv_groups = module.num_key_value_groups
        if kv_groups > 1:
            # Repeat each KV head so K/V carry as many heads as Q.
            key = key.repeat_interleave(kv_groups, dim=1)
            value = value.repeat_interleave(kv_groups, dim=1)

        # Additive mask; the original annotates it as (B, 1, Lq, Lk).
        additive_mask = attention_mask

        def score_mod(score, b, h, q_idx, kv_idx):
            # The mask has a singleton head axis, so ``h`` is not used.
            return score + additive_mask[b, 0, q_idx, kv_idx]

        attended = flex_kernel(query, key, value, score_mod=score_mod, scale=scaling)
        # Second element mirrors the (output, attention-weights) return pair;
        # no weights are produced here.
        return attended.transpose(1, 2).contiguous(), None

    modeling_pi05.sdpa_attention_forward = flex_forward
|
||||||
|
|
||||||
|
|
||||||
|
def _build_policy(args, device: torch.device):
    """Create a randomly initialised ``PI052Policy`` for benchmarking.

    Args:
        args: Parsed CLI namespace; reads ``action_dim``, ``state_dim``,
            ``dtype``, ``knowledge_insulation``, ``train_full``,
            ``lang_tokens``, ``compile_model``, ``compile_mode`` and
            ``gradient_checkpointing``.
        device: Device the policy is moved to after construction.

    Returns:
        ``(policy, config)`` with the policy in training mode.
    """
    from lerobot.configs.types import FeatureType, PolicyFeature
    from lerobot.policies.pi052.configuration_pi052 import PI052Config
    from lerobot.policies.pi052.modeling_pi052 import PI052Policy

    # Production has ``unfreeze_lm_head=True`` + ``text_loss_weight>0``,
    # which flips ``train_expert_only=False`` in __post_init__ and
    # makes the whole PaliGemma + Gemma-expert stack trainable. We
    # mirror that here so the optimizer-state count reflects reality;
    # the loss path still goes through ``PI05Policy.forward`` because
    # ``text_labels`` / FAST tokens are absent from the synthetic batch
    # (see ``PI052Policy.forward`` early-return).
    text_weight = 0.0
    if args.train_full:
        text_weight = 1e-3

    config_kwargs = {
        "max_action_dim": args.action_dim,
        "max_state_dim": args.state_dim,
        "dtype": args.dtype,
        "knowledge_insulation": args.knowledge_insulation,
        "text_loss_weight": text_weight,
        "flow_loss_weight": 1.0,
        "enable_fast_action_loss": False,
        "unfreeze_lm_head": args.train_full,
        "tokenizer_max_length": args.lang_tokens,
        "device": "cuda",
        "compile_model": args.compile_model,
        "compile_mode": args.compile_mode,
    }
    config = PI052Config(**config_kwargs)

    # One proprioceptive state vector and a single 224x224 RGB camera in,
    # one action vector out.
    state_feature = PolicyFeature(type=FeatureType.STATE, shape=(args.state_dim,))
    image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
    action_feature = PolicyFeature(type=FeatureType.ACTION, shape=(args.action_dim,))
    config.input_features = {
        "observation.state": state_feature,
        "observation.images.base_0_rgb": image_feature,
    }
    config.output_features = {"action": action_feature}

    policy = PI052Policy(config)
    policy.to(device)
    if args.gradient_checkpointing:
        policy.model.gradient_checkpointing_enable()
    policy.train()
    return policy, config
|
||||||
|
|
||||||
|
|
||||||
|
def _build_batch(args, config, device: torch.device) -> dict:
    """Assemble a synthetic batch matching the training-loop input contract.

    Args:
        args: Parsed CLI namespace; reads ``batch_size``, ``lang_tokens``,
            ``state_dim`` and ``action_dim``.
        config: Policy config; only ``chunk_size`` is read.
        device: Device every tensor is allocated on.

    Returns:
        A dict of random language tokens, one RGB image stream, state and
        action tensors, all-valid masks, and a per-sample task string.
    """
    from lerobot.utils.constants import (
        ACTION,
        OBS_LANGUAGE_ATTENTION_MASK,
        OBS_LANGUAGE_TOKENS,
    )

    batch_size = args.batch_size
    seq_len = args.lang_tokens
    horizon = config.chunk_size

    # Random draws are made in the same order as the dict entries below so the
    # generated values do not depend on how this function is laid out.
    lang_tokens = torch.randint(0, 250000, (batch_size, seq_len), device=device)
    lang_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
    image = torch.rand(batch_size, 3, 224, 224, device=device)
    image_valid = torch.ones(batch_size, dtype=torch.bool, device=device)
    state = torch.randn(batch_size, args.state_dim, device=device)
    action = torch.randn(batch_size, horizon, args.action_dim, device=device)
    action_pad = torch.zeros(batch_size, horizon, dtype=torch.bool, device=device)

    batch = {}
    batch[OBS_LANGUAGE_TOKENS] = lang_tokens
    batch[OBS_LANGUAGE_ATTENTION_MASK] = lang_mask
    batch["observation.images.base_0_rgb"] = image
    batch["observation.images.base_0_rgb_padding_mask"] = image_valid
    batch["observation.state"] = state
    batch[ACTION] = action
    batch["action_is_pad"] = action_pad
    batch["task"] = ["bench task"] * batch_size
    return batch
|
||||||
|
|
||||||
|
|
||||||
|
def _step(policy, batch, optimizer=None) -> torch.Tensor:
    """Run one forward/backward pass and, optionally, one optimizer update.

    Args:
        policy: Object exposing ``forward(batch) -> (loss, extra)`` and
            ``parameters()``.
        batch: Input passed unchanged to ``policy.forward``.
        optimizer: When given, ``step()`` then ``zero_grad(set_to_none=True)``
            are called; when ``None``, gradients are dropped by hand.

    Returns:
        The loss, detached from the autograd graph.
    """
    outputs = policy.forward(batch)
    loss = outputs[0]
    loss.backward()

    if optimizer is None:
        # No optimizer to clear gradients for us: drop them manually so they
        # do not accumulate across benchmark iterations.
        for param in (p for p in policy.parameters() if p.grad is not None):
            param.grad = None
    else:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    return loss.detach()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args() -> argparse.Namespace:
    """Build the benchmark CLI, parse ``sys.argv`` and validate the result."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--attn", choices=["sdpa", "eager", "flex"], default="sdpa")
    parser.add_argument(
        "--kernels",
        default="none",
        help=(
            "Liger sub-kernels to enable, comma-separated. Choose from "
            f"{_LIGER_SUBKERNELS} or use 'all' / 'none' (default). Applied "
            "via apply_liger_kernel_to_paligemma() BEFORE model build."
        ),
    )
    parser.add_argument(
        "--compile",
        dest="compile_model",
        action="store_true",
        help="Set policy.config.compile_model=True (torch.compile the forward).",
    )
    parser.add_argument(
        "--compile-mode",
        default="default",
        help="torch.compile mode (default | reduce-overhead | max-autotune).",
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--warmup", type=int, default=8)
    parser.add_argument("--steps", type=int, default=40)
    parser.add_argument("--lang-tokens", type=int, default=512)
    parser.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16")
    parser.add_argument("--action-dim", type=int, default=14)
    parser.add_argument("--state-dim", type=int, default=14)
    # BooleanOptionalAction rather than ``store_true``: with ``default=True`` a
    # ``store_true`` flag can never be turned off. ``--knowledge-insulation``
    # keeps working; ``--no-knowledge-insulation`` is the new off switch.
    parser.add_argument(
        "--knowledge-insulation",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument(
        "--gradient-checkpointing",
        dest="gradient_checkpointing",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument(
        "--optimizer",
        choices=["none", "adamw", "adamw_fused"],
        default="adamw_fused",
        help=(
            "Whether to include an AdamW step in the timed iteration. "
            "'none' mirrors the fwd+bwd-only original bench; 'adamw' / "
            "'adamw_fused' add the realistic ~2x param-bytes optimizer "
            "state and ``optimizer.step()`` cost."
        ),
    )
    parser.add_argument(
        "--train-full",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Mirror production: unfreeze the PaliGemma backbone (full "
            "~3B trainable params) instead of training only the 300M "
            "action expert."
        ),
    )
    args = parser.parse_args()

    # The per-step time is ``total / steps``; reject a non-positive count up
    # front instead of dividing by zero after model build and warmup.
    if args.steps < 1:
        parser.error("--steps must be a positive integer")
    return args


def main() -> int:
    """Time ``PI052Policy`` training steps on one GPU and print a summary.

    Applies the requested attention / Liger patches, builds the policy and a
    synthetic batch, runs ``--warmup`` untimed steps followed by ``--steps``
    timed steps (bracketed by CUDA events), then prints the mean step time,
    steps/sec, peak allocated memory, optimizer-state size and trainable
    parameter count.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If CUDA is unavailable or the CLI arguments are invalid.
    """
    args = _parse_args()

    if not torch.cuda.is_available():
        raise SystemExit("Benchmark requires CUDA; submit via slurm (srun/sbatch).")

    # Both kinds of patch must land before the policy is built.
    if args.attn == "eager":
        _maybe_patch_eager()
    elif args.attn == "flex":
        _maybe_patch_flex()

    liger_flags = _maybe_patch_liger(args.kernels)

    device = torch.device("cuda")
    torch.cuda.reset_peak_memory_stats()

    policy, config = _build_policy(args, device)
    batch = _build_batch(args, config, device)

    trainable = [p for p in policy.parameters() if p.requires_grad]
    trainable_params = sum(p.numel() for p in trainable)
    optimizer = None
    if args.optimizer != "none":
        optimizer = torch.optim.AdamW(
            trainable, lr=5e-5, fused=(args.optimizer == "adamw_fused")
        )

    for _ in range(args.warmup):
        _step(policy, batch, optimizer)
    torch.cuda.synchronize()

    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    starter.record()
    for _ in range(args.steps):
        _step(policy, batch, optimizer)
    ender.record()
    # ``elapsed_time`` is only valid once both events have completed.
    torch.cuda.synchronize()
    total_ms = starter.elapsed_time(ender)
    step_ms = total_ms / args.steps
    # The peak covers model build, warmup and the timed loop (stats were
    # reset before the policy was created).
    peak_gb = torch.cuda.max_memory_allocated() / (1024**3)

    optim_gb = 0.0
    if optimizer is not None:
        optim_bytes = sum(
            value.numel() * value.element_size()
            for param_state in optimizer.state.values()
            for value in param_state.values()
            if torch.is_tensor(value)
        )
        optim_gb = optim_bytes / (1024**3)

    liger_on = ",".join(k for k, v in liger_flags.items() if v) or "none"
    name = (
        f"{args.attn:>5} | BS={args.batch_size} | L={args.lang_tokens} | "
        f"KI={args.knowledge_insulation} | GC={args.gradient_checkpointing} | "
        f"compile={args.compile_model} | liger={liger_on} | opt={args.optimizer} | dtype={args.dtype}"
    )
    print(
        f"{name}\n step_ms={step_ms:.1f} steps/sec={1000.0 / step_ms:.3f} "
        f"peak_mem={peak_gb:.2f} GiB optim_state={optim_gb:.2f} GiB "
        f"trainable_params={trainable_params / 1e9:.2f}B"
    )

    # Drop every reference to the parameters — the optimizer and the
    # ``trainable`` list hold them too — so the memory can actually be
    # released by ``empty_cache``.
    del policy, batch, optimizer, trainable
    gc.collect()
    torch.cuda.empty_cache()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
#SBATCH --job-name=bench-pi052-attn
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:30:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_%j.out

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader

python -c "import torch; print('torch', torch.__version__, 'cuda', torch.version.cuda)"

# Run one benchmark configuration. A failing configuration (e.g. an OOM at a
# large batch size) must not stop the sweep, but it is reported with its exit
# code so the log shows which rows are missing and why.
run() {
    echo
    echo "--- $* ---"
    local rc=0
    python examples/benchmark/bench_pi052_step.py "$@" || rc=$?
    if (( rc != 0 )); then
        echo "!!! FAILED (exit ${rc}): $*"
    fi
}

# Attention parity benchmark — same shapes, different attention kernel.
run --attn eager --batch-size 8
run --attn sdpa --batch-size 8

# Headroom benchmark — does SDPA's memory cut allow a bigger micro-batch?
run --attn sdpa --batch-size 12
run --attn sdpa --batch-size 16
run --attn sdpa --batch-size 24
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
#!/bin/bash
#SBATCH --job-name=bench-pi052-v2
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:45:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v2_%j.out

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader

# Run one benchmark configuration. A failing configuration (e.g. an OOM at a
# large batch size) must not stop the sweep, but it is reported with its exit
# code so the log shows which rows are missing and why.
run() {
    echo
    echo "--- $* ---"
    local rc=0
    python examples/benchmark/bench_pi052_step.py "$@" || rc=$?
    if (( rc != 0 )); then
        echo "!!! FAILED (exit ${rc}): $*"
    fi
}

# A: GC ON — see if the selective-AC change (one less recompute level)
# narrows the eager vs SDPA gap at BS=8.
run --attn eager --batch-size 8
run --attn sdpa --batch-size 8

# B: GC OFF — isolate the raw attention-kernel cost & memory delta.
run --attn eager --batch-size 4 --no-gradient-checkpointing
run --attn sdpa --batch-size 4 --no-gradient-checkpointing

# C: SDPA + GC headroom sweep — where does it OOM?
run --attn sdpa --batch-size 16
run --attn sdpa --batch-size 24
run --attn sdpa --batch-size 32
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
#SBATCH --job-name=bench-pi052-v3
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:45:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v3_%j.out

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader

# Run one benchmark configuration. A failing configuration (e.g. a compile
# error or an OOM) must not stop the sweep, but it is reported with its exit
# code so the log shows which rows are missing and why.
run() {
    echo
    echo "--- $* ---"
    local rc=0
    python examples/benchmark/bench_pi052_step.py "$@" || rc=$?
    if (( rc != 0 )); then
        echo "!!! FAILED (exit ${rc}): $*"
    fi
}

# Compile sweep: does torch.compile + SDPA give a non-trivial boost on
# top of the bare SDPA path?
run --attn sdpa --batch-size 8 --compile
run --attn sdpa --batch-size 16 --compile

# FlexAttention sweep (experimental): score_mod adds the additive bias
# in-kernel; expect a long first-step compile, then SDPA-or-better steady
# state.
run --attn flex --batch-size 8
run --attn flex --batch-size 16
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
#!/bin/bash
#SBATCH --job-name=bench-pi052-v4
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=01:00:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v4_%j.out

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

# /fsx triton cache is shared across nodes with different glibc versions
# — kernels built on one node trip GLIBC_2.34-not-found on another. Use
# a node-local cache per job to side-step that.
# Fall back to the shell PID when SLURM_JOB_ID is unset (e.g. a manual run)
# so `set -u` does not abort the script before the first benchmark.
cache_tag="${SLURM_JOB_ID:-manual_$$}"
export TRITON_CACHE_DIR="/tmp/triton_${cache_tag}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${cache_tag}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
# The caches are per-job by construction; remove them on exit so /tmp on the
# compute node does not accumulate one directory pair per sweep.
trap 'rm -rf "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"' EXIT

echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
# sed consumes the whole stream, so ldd cannot receive SIGPIPE (which would
# fail the pipeline under `set -o pipefail`) the way it can with `head -1`.
ldd --version | sed -n '1p'

# Run one benchmark configuration. A failing configuration (e.g. a compile
# error or an OOM) must not stop the sweep, but it is reported with its exit
# code so the log shows which rows are missing and why.
run() {
    echo
    echo "--- $* ---"
    local rc=0
    python examples/benchmark/bench_pi052_step.py "$@" || rc=$?
    if (( rc != 0 )); then
        echo "!!! FAILED (exit ${rc}): $*"
    fi
}

# compile path on top of SDPA + selective AC
run --attn sdpa --batch-size 8 --compile
run --attn sdpa --batch-size 16 --compile

# FlexAttention experimental
run --attn flex --batch-size 8
run --attn flex --batch-size 16
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
#SBATCH --job-name=bench-pi052-v5
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:45:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v5_%j.out

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

# Node-local Triton / Inductor caches, one pair per job. Fall back to the
# shell PID when SLURM_JOB_ID is unset (e.g. a manual run) so `set -u` does
# not abort the script before the first benchmark.
cache_tag="${SLURM_JOB_ID:-manual_$$}"
export TRITON_CACHE_DIR="/tmp/triton_${cache_tag}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${cache_tag}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
# The caches are per-job by construction; remove them on exit so /tmp on the
# compute node does not accumulate one directory pair per sweep.
trap 'rm -rf "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"' EXIT

echo "=== Node: $(hostname) ==="

# Run one benchmark configuration. A failing configuration (e.g. a compile
# error or an OOM) must not stop the sweep, but it is reported with its exit
# code so the log shows which rows are missing and why.
run() {
    echo
    echo "--- $* ---"
    local rc=0
    python examples/benchmark/bench_pi052_step.py "$@" || rc=$?
    if (( rc != 0 )); then
        echo "!!! FAILED (exit ${rc}): $*"
    fi
}

# compile_mode=default (graph-only, no autotune) is the right knob with
# gradient checkpointing — max-autotune in v4 was 2x slower than no-compile.
run --attn sdpa --batch-size 8 --compile --compile-mode default
run --attn sdpa --batch-size 16 --compile --compile-mode default
run --attn sdpa --batch-size 8 --compile --compile-mode reduce-overhead
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
#SBATCH --job-name=bench-pi052-v6-bs32
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:30:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v6_%j.out

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

# Node-local Triton / Inductor caches, one pair per job. Fall back to the
# shell PID when SLURM_JOB_ID is unset (e.g. a manual run) so `set -u` does
# not abort the script before the first benchmark.
cache_tag="${SLURM_JOB_ID:-manual_$$}"
export TRITON_CACHE_DIR="/tmp/triton_${cache_tag}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${cache_tag}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
# The caches are per-job by construction; remove them on exit so /tmp on the
# compute node does not accumulate one directory pair per sweep.
trap 'rm -rf "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"' EXIT

echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader

# Run one benchmark configuration. A failing configuration (e.g. an OOM) must
# not abort the job, but it is reported with its exit code so the log shows
# that the row is missing and why.
run() {
    echo
    echo "--- $* ---"
    local rc=0
    python examples/benchmark/bench_pi052_step.py "$@" || rc=$?
    if (( rc != 0 )); then
        echo "!!! FAILED (exit ${rc}): $*"
    fi
}

# BS=32 with the production settings (SDPA + compile=default).
run --attn sdpa --batch-size 32 --compile --compile-mode default
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
#!/bin/bash
#SBATCH --job-name=bench-pi052-v7-opt
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:45:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v7_%j.out

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

# Node-local Triton / Inductor caches, one pair per job. Fall back to the
# shell PID when SLURM_JOB_ID is unset (e.g. a manual run) so `set -u` does
# not abort the script before the first benchmark.
cache_tag="${SLURM_JOB_ID:-manual_$$}"
export TRITON_CACHE_DIR="/tmp/triton_${cache_tag}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${cache_tag}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
# The caches are per-job by construction; remove them on exit so /tmp on the
# compute node does not accumulate one directory pair per sweep.
trap 'rm -rf "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"' EXIT

echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader

# Run one benchmark configuration. A failing configuration (e.g. an OOM at a
# large batch size) must not stop the sweep, but it is reported with its exit
# code so the log shows which rows are missing and why.
run() {
    echo
    echo "--- $* ---"
    local rc=0
    python examples/benchmark/bench_pi052_step.py "$@" || rc=$?
    if (( rc != 0 )); then
        echo "!!! FAILED (exit ${rc}): $*"
    fi
}

# Realistic full-step memory: fwd + bwd + AdamW step. The original
# sweep was fwd+bwd-only and undercounted memory by the optimizer-
# state size (~2x param bytes for AdamW). This run confirms BS=16
# and BS=32 still fit with the optimizer in residency.
run --attn sdpa --batch-size 16 --compile --compile-mode default --optimizer adamw_fused
run --attn sdpa --batch-size 32 --compile --compile-mode default --optimizer adamw_fused

# Without compile, in case the production cluster has compile issues.
run --attn sdpa --batch-size 16 --optimizer adamw_fused
run --attn sdpa --batch-size 32 --optimizer adamw_fused
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#SBATCH --job-name=bench-pi052-v8-bs40-dtype
|
||||||
|
#SBATCH --partition=hopper-prod
|
||||||
|
#SBATCH --qos=high
|
||||||
|
#SBATCH --time=00:45:00
|
||||||
|
#SBATCH --ntasks=1
|
||||||
|
#SBATCH --gpus-per-task=1
|
||||||
|
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v8_%j.out
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||||
|
|
||||||
|
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||||
|
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||||
|
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||||
|
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||||
|
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||||
|
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||||
|
|
||||||
|
echo "=== Node: $(hostname) ==="
|
||||||
|
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
||||||
|
|
||||||
|
run() {
|
||||||
|
echo
|
||||||
|
echo "--- $* ---"
|
||||||
|
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||||
|
}
|
||||||
|
|
||||||
|
# Confirm BS=40 fits on a single H100 with the optimizer in residency.
|
||||||
|
run --attn sdpa --batch-size 40 --compile --compile-mode default --optimizer adamw_fused
|
||||||
|
|
||||||
|
# Dtype A/B at modest batch — fp32 needs ~2x the memory of bf16, so we
|
||||||
|
# drop to BS=4 to keep both runs comparable instead of OOMing fp32.
|
||||||
|
run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype bfloat16
|
||||||
|
run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype float32
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
distributed_type: FSDP
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
enable_cpu_affinity: false
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_activation_checkpointing: false
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_backward_prefetch: BACKWARD_PRE
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_forward_prefetch: false
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_reshard_after_forward: true
|
||||||
|
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: GemmaDecoderLayer,SiglipEncoderLayer
|
||||||
|
fsdp_use_orig_params: true
|
||||||
|
fsdp_version: 2
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
@@ -178,7 +178,6 @@ N_COLOR_CHANNELS = 3
|
|||||||
|
|
||||||
|
|
||||||
# config
|
# config
|
||||||
@strict
|
|
||||||
class GR00TN15Config(PretrainedConfig):
|
class GR00TN15Config(PretrainedConfig):
|
||||||
model_type = "gr00t_n1_5"
|
model_type = "gr00t_n1_5"
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
|||||||
|
|
||||||
# Conditional import for type checking and lazy loading
|
# Conditional import for type checking and lazy loading
|
||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers.cache_utils import DynamicCache
|
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
from transformers.models.gemma import modeling_gemma
|
from transformers.models.gemma import modeling_gemma
|
||||||
|
|
||||||
@@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
CONFIG_MAPPING = None
|
CONFIG_MAPPING = None
|
||||||
DynamicCache = None
|
|
||||||
modeling_gemma = None
|
modeling_gemma = None
|
||||||
PiGemmaForCausalLM = None
|
PiGemmaForCausalLM = None
|
||||||
_gated_residual = None
|
_gated_residual = None
|
||||||
@@ -139,15 +138,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
|||||||
return att_2d_masks & pad_2d_masks
|
return att_2d_masks & pad_2d_masks
|
||||||
|
|
||||||
|
|
||||||
def clone_past_key_values(past_key_values):
|
|
||||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
|
||||||
return DynamicCache(
|
|
||||||
tuple(
|
|
||||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pad_vector(vector, new_dim):
|
def pad_vector(vector, new_dim):
|
||||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||||
|
|
||||||
@@ -233,14 +223,53 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
return padded_images
|
return padded_images
|
||||||
|
|
||||||
|
|
||||||
|
def sdpa_attention_forward(
|
||||||
|
module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None,
|
||||||
|
scaling: float,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
):
|
||||||
|
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
|
||||||
|
``torch.nn.functional.scaled_dot_product_attention``.
|
||||||
|
|
||||||
|
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
|
||||||
|
bias masks (the FA backend only accepts causal/sliding-window). On
|
||||||
|
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
|
||||||
|
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
|
||||||
|
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
|
||||||
|
"""
|
||||||
|
n_rep = module.num_key_value_groups
|
||||||
|
if n_rep > 1:
|
||||||
|
key = key.repeat_interleave(n_rep, dim=1)
|
||||||
|
value = value.repeat_interleave(n_rep, dim=1)
|
||||||
|
if attention_mask is not None and attention_mask.dtype != query.dtype:
|
||||||
|
attention_mask = attention_mask.to(dtype=query.dtype)
|
||||||
|
attn_output = F.scaled_dot_product_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
dropout_p=dropout if module.training else 0.0,
|
||||||
|
is_causal=False,
|
||||||
|
scale=scaling,
|
||||||
|
)
|
||||||
|
return attn_output.transpose(1, 2).contiguous(), None
|
||||||
|
|
||||||
|
|
||||||
# Define the complete layer computation function for gradient checkpointing
|
# Define the complete layer computation function for gradient checkpointing
|
||||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
def compute_layer_complete(
|
||||||
|
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||||
|
):
|
||||||
|
models = [paligemma.model.language_model, gemma_expert.model]
|
||||||
query_states = []
|
query_states = []
|
||||||
key_states = []
|
key_states = []
|
||||||
value_states = []
|
value_states = []
|
||||||
gates = []
|
gates = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
layer = layers[i]
|
layer = models[i].layers[layer_idx]
|
||||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||||
gates.append(gate)
|
gates.append(gate)
|
||||||
input_shape = hidden_states.shape[:-1]
|
input_shape = hidden_states.shape[:-1]
|
||||||
@@ -262,16 +291,14 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
|||||||
device=query_states.device,
|
device=query_states.device,
|
||||||
dtype=query_states.dtype,
|
dtype=query_states.dtype,
|
||||||
)
|
)
|
||||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||||
)
|
)
|
||||||
batch_size = query_states.shape[0]
|
batch_size = query_states.shape[0]
|
||||||
paligemma_layer = layers[0]
|
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||||
scaling = paligemma_layer.self_attn.scaling
|
att_output, _ = sdpa_attention_forward(
|
||||||
# Attention computation
|
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
|
||||||
paligemma_layer.self_attn,
|
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
@@ -279,13 +306,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
|||||||
scaling,
|
scaling,
|
||||||
)
|
)
|
||||||
# Get head_dim from the current layer, not from the model
|
# Get head_dim from the current layer, not from the model
|
||||||
head_dim = paligemma_layer.self_attn.head_dim
|
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||||
# Process layer outputs
|
# Process layer outputs
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
start_pos = 0
|
start_pos = 0
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
layer = layers[i]
|
layer = models[i].layers[layer_idx]
|
||||||
end_pos = start_pos + hidden_states.shape[1]
|
end_pos = start_pos + hidden_states.shape[1]
|
||||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||||
@@ -450,13 +477,13 @@ class PaliGemmaWithExpertModel(
|
|||||||
if image.dtype != torch.float32:
|
if image.dtype != torch.float32:
|
||||||
image = image.to(torch.float32)
|
image = image.to(torch.float32)
|
||||||
image_outputs = self.paligemma.model.get_image_features(image)
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
features = image_outputs.pooler_output
|
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||||
if features.dtype != out_dtype:
|
if features.dtype != out_dtype:
|
||||||
features = features.to(out_dtype)
|
features = features.to(out_dtype)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -494,9 +521,8 @@ class PaliGemmaWithExpertModel(
|
|||||||
prefix_output = None
|
prefix_output = None
|
||||||
prefix_past_key_values = None
|
prefix_past_key_values = None
|
||||||
else:
|
else:
|
||||||
paligemma_layers = self.paligemma.model.language_model.layers
|
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||||
gemma_expert_layers = self.gemma_expert.model.layers
|
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
|
||||||
|
|
||||||
# Check if gradient checkpointing is enabled for any of the models
|
# Check if gradient checkpointing is enabled for any of the models
|
||||||
use_gradient_checkpointing = (
|
use_gradient_checkpointing = (
|
||||||
@@ -506,39 +532,36 @@ class PaliGemmaWithExpertModel(
|
|||||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||||
|
|
||||||
# Process all layers with gradient checkpointing if enabled
|
# Process all layers with gradient checkpointing if enabled
|
||||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
for layer_idx in range(num_layers):
|
||||||
if use_gradient_checkpointing:
|
if use_gradient_checkpointing:
|
||||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||||
compute_layer_complete,
|
compute_layer_complete,
|
||||||
|
layer_idx,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
adarms_cond,
|
adarms_cond,
|
||||||
use_reentrant=False,
|
use_reentrant=False,
|
||||||
preserve_rng_state=False,
|
preserve_rng_state=False,
|
||||||
layers=layers,
|
paligemma=self.paligemma,
|
||||||
rotary_emb=rotary_emb,
|
gemma_expert=self.gemma_expert,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inputs_embeds = compute_layer_complete(
|
inputs_embeds = compute_layer_complete(
|
||||||
|
layer_idx,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
adarms_cond,
|
adarms_cond,
|
||||||
layers=layers,
|
paligemma=self.paligemma,
|
||||||
rotary_emb=rotary_emb,
|
gemma_expert=self.gemma_expert,
|
||||||
)
|
)
|
||||||
|
|
||||||
# final norm
|
# final norm
|
||||||
final_norms = (
|
|
||||||
self.paligemma.model.language_model.norm,
|
|
||||||
self.gemma_expert.model.norm,
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
|
|
||||||
@@ -678,7 +701,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Process language tokens
|
# Process language tokens
|
||||||
def lang_embed_func(tokens):
|
def lang_embed_func(tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||||
return lang_emb
|
lang_emb_dim = lang_emb.shape[-1]
|
||||||
|
return lang_emb * math.sqrt(lang_emb_dim)
|
||||||
|
|
||||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||||
embs.append(lang_emb)
|
embs.append(lang_emb)
|
||||||
@@ -767,19 +791,20 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
||||||
|
|
||||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
# Selective AC: rely on the per-layer checkpoint inside
|
||||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
# ``PaliGemmaWithExpertModel.forward`` (which wraps each
|
||||||
attention_mask=att_2d_masks_4d,
|
# transformer block individually). The previous outer
|
||||||
position_ids=position_ids,
|
# ``_apply_checkpoint(forward_func, ...)`` doubled up — it
|
||||||
past_key_values=None,
|
# re-ran the full backbone forward during backward *and* each
|
||||||
inputs_embeds=[prefix_embs, suffix_embs],
|
# block's own checkpoint re-ran during that recompute. Pure
|
||||||
use_cache=False,
|
# waste with SDPA, which already streams attention activations.
|
||||||
adarms_cond=[None, adarms_cond],
|
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||||
)
|
attention_mask=att_2d_masks_4d,
|
||||||
return suffix_out
|
position_ids=position_ids,
|
||||||
|
past_key_values=None,
|
||||||
suffix_out = self._apply_checkpoint(
|
inputs_embeds=[prefix_embs, suffix_embs],
|
||||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
use_cache=False,
|
||||||
|
adarms_cond=[None, adarms_cond],
|
||||||
)
|
)
|
||||||
|
|
||||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||||
@@ -900,7 +925,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
past_key_values = clone_past_key_values(past_key_values)
|
past_key_values = copy.deepcopy(past_key_values)
|
||||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||||
attention_mask=full_att_2d_masks_4d,
|
attention_mask=full_att_2d_masks_4d,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|||||||
@@ -291,11 +291,13 @@ def _compute_layer_ki(
|
|||||||
if mask_for_action.dtype != Q_action.dtype:
|
if mask_for_action.dtype != Q_action.dtype:
|
||||||
mask_for_action = mask_for_action.to(dtype=Q_action.dtype)
|
mask_for_action = mask_for_action.to(dtype=Q_action.dtype)
|
||||||
|
|
||||||
att_vlm, _ = modeling_gemma.eager_attention_forward(
|
from ..pi05.modeling_pi05 import sdpa_attention_forward # noqa: PLC0415
|
||||||
|
|
||||||
|
att_vlm, _ = sdpa_attention_forward(
|
||||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||||
Q_vlm, K_for_vlm, V_for_vlm, mask_for_vlm, scaling,
|
Q_vlm, K_for_vlm, V_for_vlm, mask_for_vlm, scaling,
|
||||||
)
|
)
|
||||||
att_action, _ = modeling_gemma.eager_attention_forward(
|
att_action, _ = sdpa_attention_forward(
|
||||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||||
Q_action, K_for_action, V_for_action, mask_for_action, scaling,
|
Q_action, K_for_action, V_for_action, mask_for_action, scaling,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,155 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Numerical-parity tests for the SDPA attention port.
|
||||||
|
|
||||||
|
``pi05`` / ``pi052`` replaced the per-layer call from
|
||||||
|
``modeling_gemma.eager_attention_forward`` with
|
||||||
|
``sdpa_attention_forward`` (PyTorch SDPA + GQA repeat). The forward
|
||||||
|
output must be numerically equivalent (within dtype tolerance) on the masks
|
||||||
|
this model actually uses — block-bidirectional with an arbitrary
|
||||||
|
additive bias — otherwise we silently change training behaviour.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
|
from transformers.models.gemma import modeling_gemma # noqa: E402
|
||||||
|
|
||||||
|
from lerobot.policies.pi05.modeling_pi05 import ( # noqa: E402
|
||||||
|
make_att_2d_masks,
|
||||||
|
sdpa_attention_forward,
|
||||||
|
)
|
||||||
|
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_self_attn(num_kv_groups: int, training: bool = False):
|
||||||
|
"""Bare module surface that both forwards read."""
|
||||||
|
return SimpleNamespace(
|
||||||
|
num_key_value_groups=num_kv_groups,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_inputs(
|
||||||
|
bsize: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
seq_len: int,
|
||||||
|
head_dim: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int = 0,
|
||||||
|
):
|
||||||
|
g = torch.Generator(device="cpu").manual_seed(seed)
|
||||||
|
q = torch.randn(bsize, num_heads, seq_len, head_dim, dtype=dtype, generator=g)
|
||||||
|
k = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
|
||||||
|
v = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
|
||||||
|
def _block_bidirectional_mask(
|
||||||
|
bsize: int, seq_len: int, block_sizes: list[int], dtype: torch.dtype
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Mimic ``_prepare_attention_masks_4d`` on a block layout that
|
||||||
|
matches ``[images, language, suffix]`` from ``embed_prefix`` +
|
||||||
|
``embed_suffix``: every block bidirectional internally, earlier
|
||||||
|
blocks visible to later ones via the cumulative-block rule.
|
||||||
|
"""
|
||||||
|
assert sum(block_sizes) == seq_len
|
||||||
|
att_marks = []
|
||||||
|
for i, n in enumerate(block_sizes):
|
||||||
|
att_marks += [1 if i > 0 else 0] + [0] * (n - 1)
|
||||||
|
pad = torch.ones(bsize, seq_len, dtype=torch.bool)
|
||||||
|
att = torch.tensor(att_marks, dtype=torch.bool)[None].expand(bsize, seq_len)
|
||||||
|
att_2d = make_att_2d_masks(pad, att)
|
||||||
|
bias = torch.where(
|
||||||
|
att_2d[:, None, :, :],
|
||||||
|
torch.zeros((), dtype=dtype),
|
||||||
|
torch.tensor(OPENPI_ATTENTION_MASK_VALUE, dtype=dtype),
|
||||||
|
)
|
||||||
|
return bias
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_heads,num_kv_heads,head_dim",
|
||||||
|
[
|
||||||
|
(8, 1, 256), # gemma_2b / paligemma config
|
||||||
|
(8, 8, 64), # MHA control (no GQA repeat)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_sdpa_parity_with_eager_block_bidirectional(num_heads, num_kv_heads, head_dim):
|
||||||
|
"""SDPA forward output matches the eager softmax(QK^T)@V on the
|
||||||
|
block-bidirectional mask layout pi05 actually uses."""
|
||||||
|
bsize, seq_len = 2, 13
|
||||||
|
block_sizes = [4, 5, 4] # images, language, suffix-style blocks
|
||||||
|
dtype = torch.float32 # cpu math kernel — keep fp32 for tight tol
|
||||||
|
scaling = head_dim ** -0.5
|
||||||
|
|
||||||
|
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, dtype)
|
||||||
|
mask = _block_bidirectional_mask(bsize, seq_len, block_sizes, dtype)
|
||||||
|
|
||||||
|
module = _mock_self_attn(num_heads // num_kv_heads)
|
||||||
|
|
||||||
|
out_eager, _ = modeling_gemma.eager_attention_forward(
|
||||||
|
module, q, k, v, mask, scaling
|
||||||
|
)
|
||||||
|
out_sdpa, _ = sdpa_attention_forward(
|
||||||
|
module, q, k, v, mask, scaling
|
||||||
|
)
|
||||||
|
assert out_eager.shape == out_sdpa.shape
|
||||||
|
torch.testing.assert_close(out_sdpa, out_eager, atol=1e-5, rtol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sdpa_parity_bf16():
|
||||||
|
"""bf16 path — looser tolerance, must still match eager."""
|
||||||
|
bsize, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 1, 17, 256
|
||||||
|
scaling = head_dim ** -0.5
|
||||||
|
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.bfloat16)
|
||||||
|
mask = _block_bidirectional_mask(bsize, seq_len, [5, 6, 6], torch.bfloat16)
|
||||||
|
module = _mock_self_attn(num_heads // num_kv_heads)
|
||||||
|
|
||||||
|
out_eager, _ = modeling_gemma.eager_attention_forward(
|
||||||
|
module, q, k, v, mask, scaling
|
||||||
|
)
|
||||||
|
out_sdpa, _ = sdpa_attention_forward(
|
||||||
|
module, q, k, v, mask, scaling
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(out_sdpa, out_eager, atol=2e-2, rtol=2e-2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sdpa_parity_backward():
|
||||||
|
"""Gradients flow through SDPA and match the eager path within
|
||||||
|
fp32 tolerance — critical for any training-side parity claim."""
|
||||||
|
bsize, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 9, 32
|
||||||
|
scaling = head_dim ** -0.5
|
||||||
|
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.float32)
|
||||||
|
q.requires_grad_(True); k.requires_grad_(True); v.requires_grad_(True)
|
||||||
|
mask = _block_bidirectional_mask(bsize, seq_len, [3, 3, 3], torch.float32)
|
||||||
|
module = _mock_self_attn(num_heads // num_kv_heads)
|
||||||
|
|
||||||
|
out_e, _ = modeling_gemma.eager_attention_forward(module, q, k, v, mask, scaling)
|
||||||
|
g_q_e, g_k_e, g_v_e = torch.autograd.grad(out_e.sum(), [q, k, v])
|
||||||
|
|
||||||
|
out_s, _ = sdpa_attention_forward(module, q, k, v, mask, scaling)
|
||||||
|
g_q_s, g_k_s, g_v_s = torch.autograd.grad(out_s.sum(), [q, k, v])
|
||||||
|
|
||||||
|
torch.testing.assert_close(g_q_s, g_q_e, atol=1e-5, rtol=1e-4)
|
||||||
|
torch.testing.assert_close(g_k_s, g_k_e, atol=1e-5, rtol=1e-4)
|
||||||
|
torch.testing.assert_close(g_v_s, g_v_e, atol=1e-5, rtol=1e-4)
|
||||||
Reference in New Issue
Block a user