From 49133565645f9361e169dea37c7795f3f835e339 Mon Sep 17 00:00:00 2001 From: pepijn Date: Mon, 25 May 2026 21:59:20 +0000 Subject: [PATCH] pi052: SDPA attention port + selective AC + bench harness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the per-layer ``modeling_gemma.eager_attention_forward`` call with ``torch.nn.functional.scaled_dot_product_attention`` in ``compute_layer_complete`` (pi05) and ``_compute_layer_ki`` (pi052). PyTorch SDPA picks the memory-efficient kernel for the block-bidirectional 4D additive mask the dual-expert model uses (FA2 / FA3 reject it because they only accept causal / sliding-window / varlen patterns). The shared ``sdpa_attention_forward`` helper mirrors the eager signature so the call sites are unchanged. Selective AC: removes the redundant outer ``_apply_checkpoint(forward_func, ...)`` wrap in ``PI05Pytorch.forward``. Per-layer checkpointing inside ``PaliGemmaWithExpertModel.forward`` already handles activation recompute; the outer wrap was double-recomputing the whole backbone. +14% steps/sec on its own (job 22161405 vs 22161398, 1xH100). groot: drop ``@strict`` on ``GR00TN15Config`` — newer ``huggingface_hub`` rejects ``@strict`` on non-dataclass ``PretrainedConfig`` subclasses, which was blocking imports of any sibling policy through ``lerobot.policies.factory``. New ``examples/benchmark/bench_pi052_step.py`` (+ slurm sweeps v1..v8) times PI052Policy.forward+backward (optionally with AdamW) on synthetic inputs. Headline numbers on 1xH100 with KI=True, GC=True, L=512, 4.14 B trainable params, AdamW state in bf16: pre-SDPA eager BS=8 610ms 19.5 GiB -> 13.1 samples/s sdpa BS=8 + compile=default 413ms 19.5 GiB -> 19.3 samples/s sdpa BS=16 + compile=default 715ms 37.3 GiB -> 22.4 samples/s sdpa BS=32 + compile=default 1325ms 44.8 GiB -> 24.2 samples/s sdpa BS=40 + compile=default 1665ms 48.6 GiB -> 24.0 samples/s Parity tests in ``tests/policies/pi052/test_pi052_sdpa_attention.py`` cover fp32 / bf16 / GQA / MHA forward + backward — output and grads match the eager path within bf16 tolerance. Also ships ``examples/benchmark/fsdp_pi052.yaml`` (FSDP2 accelerate config wrapping GemmaDecoderLayer + SiglipEncoderLayer) for the follow-up multi-GPU memory sharding work. Co-authored-by: Cursor --- examples/benchmark/bench_pi052_kernels.slurm | 74 ++++ examples/benchmark/bench_pi052_step.py | 338 ++++++++++++++++++ examples/benchmark/bench_pi052_step.slurm | 36 ++ examples/benchmark/bench_pi052_step_v2.slurm | 39 ++ examples/benchmark/bench_pi052_step_v3.slurm | 36 ++ examples/benchmark/bench_pi052_step_v4.slurm | 41 +++ examples/benchmark/bench_pi052_step_v5.slurm | 33 ++ examples/benchmark/bench_pi052_step_v6.slurm | 31 ++ examples/benchmark/bench_pi052_step_v7.slurm | 39 ++ examples/benchmark/bench_pi052_step_v8.slurm | 36 ++ examples/benchmark/fsdp_pi052.yaml | 29 ++ src/lerobot/policies/groot/groot_n1.py | 1 - src/lerobot/policies/pi05/modeling_pi05.py | 129 ++++--- src/lerobot/policies/pi052/modeling_pi052.py | 6 +- .../pi052/test_pi052_sdpa_attention.py | 155 ++++++++ 15 files changed, 968 insertions(+), 55 deletions(-) create mode 100644 examples/benchmark/bench_pi052_kernels.slurm create mode 100644 examples/benchmark/bench_pi052_step.py create mode 100644 examples/benchmark/bench_pi052_step.slurm create mode 100644 examples/benchmark/bench_pi052_step_v2.slurm create mode 100644 examples/benchmark/bench_pi052_step_v3.slurm create mode 100644 examples/benchmark/bench_pi052_step_v4.slurm create mode 100644 examples/benchmark/bench_pi052_step_v5.slurm create mode 100644 examples/benchmark/bench_pi052_step_v6.slurm create mode 100644 examples/benchmark/bench_pi052_step_v7.slurm create mode 100644 examples/benchmark/bench_pi052_step_v8.slurm create mode 100644 examples/benchmark/fsdp_pi052.yaml create mode 100644 tests/policies/pi052/test_pi052_sdpa_attention.py diff --git a/examples/benchmark/bench_pi052_kernels.slurm b/examples/benchmark/bench_pi052_kernels.slurm new file mode 100644 index 000000000..046ed7dfe --- /dev/null +++ b/examples/benchmark/bench_pi052_kernels.slurm @@ -0,0 +1,74 @@ +#!/bin/bash +#SBATCH --job-name=bench-pi052-kernels +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=01:30:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --output=/fsx/pepijn/logs/bench_pi052_kernels_%j.out + +# HF kernels exploration via Liger's apply_liger_kernel_to_paligemma. +# Baseline (SDPA, no kernels) vs. per-subkernel ablations vs. all-on. +# Same harness as bench_pi052_step.py — only the --kernels flag varies +# across runs so any delta is attributable to the patched op(s). +# +# Subkernels exercised: rope, rms_norm, geglu, layer_norm. +# Skipped: cross_entropy / fused_linear_cross_entropy — pi052 calls +# F.cross_entropy directly and bypasses PaliGemma's forward, so those +# patches wouldn't fire without model-code changes (separate PR). + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# /fsx triton cache is shared across nodes with different glibc versions +# — kernels built on one node trip GLIBC_2.34-not-found on another. Use +# a node-local cache per job to side-step that. +export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}" +export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}" +mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR" + +echo "=== Node: $(hostname) ===" +nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader +ldd --version | head -1 + +# Liger isn't in our standard env yet — install on the compute node so +# the slurm log captures the exact version that produced the numbers. +python -m pip install -q --upgrade 'liger-kernel' +python - <<'PY' || true +from importlib.metadata import version, PackageNotFoundError +try: + print("liger-kernel", version("liger-kernel")) +except PackageNotFoundError: + print("liger-kernel: not importable") +import liger_kernel.transformers as t +print("apply_liger_kernel_to_paligemma:", hasattr(t, "apply_liger_kernel_to_paligemma")) +PY + +run() { + echo + echo "--- $* ---" + python examples/benchmark/bench_pi052_step.py "$@" || true +} + +# -- Baseline (no kernels) at the BS we actually train at. -- +run --attn sdpa --batch-size 8 --kernels none +run --attn sdpa --batch-size 16 --kernels none + +# -- Per-subkernel ablations at BS=16 to isolate each contributor. -- +run --attn sdpa --batch-size 16 --kernels rms_norm +run --attn sdpa --batch-size 16 --kernels geglu +run --attn sdpa --batch-size 16 --kernels layer_norm +run --attn sdpa --batch-size 16 --kernels rope + +# -- All-on, both BS to compare against the matched baselines above. -- +run --attn sdpa --batch-size 8 --kernels all +run --attn sdpa --batch-size 16 --kernels all + +# -- Headroom check: does kernels-all let BS=24 fit (baseline OOMs near here)? -- +run --attn sdpa --batch-size 24 --kernels none +run --attn sdpa --batch-size 24 --kernels all diff --git a/examples/benchmark/bench_pi052_step.py b/examples/benchmark/bench_pi052_step.py new file mode 100644 index 000000000..00560d54b --- /dev/null +++ b/examples/benchmark/bench_pi052_step.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark ``PI052Policy.forward + backward`` on a single GPU. + +Compares the new SDPA attention path against the eager baseline by +monkeypatching ``sdpa_attention_forward`` before the first model +forward — so both runs share identical Q/K/V plumbing and only the +attention kernel differs. Reports steps/sec and peak GPU memory. + +SLURM-only: + + sbatch examples/benchmark/bench_pi052_step.slurm + +Or one-off: + + srun --partition=hopper-prod --qos=high --gpus=1 --time=15 \\ + python examples/benchmark/bench_pi052_step.py --attn sdpa --batch-size 8 +""" + +from __future__ import annotations + +import argparse +import gc +import math +import os +import time + +import torch + + +def _maybe_patch_eager() -> None: + """Swap ``sdpa_attention_forward`` for the original eager forward. + + Must be called BEFORE PI052Policy is instantiated — the layer + compute functions resolve the symbol at call time (module-level + lookup), so this patch covers both pi05 and pi052 KI paths.""" + from transformers.models.gemma import modeling_gemma + + from lerobot.policies.pi05 import modeling_pi05 + + modeling_pi05.sdpa_attention_forward = modeling_gemma.eager_attention_forward + + +_LIGER_SUBKERNELS = ("rope", "rms_norm", "geglu", "layer_norm") + + +def _maybe_patch_liger(spec: str) -> dict: + """Globally patch PaliGemma/Gemma/Siglip modules with Liger Triton kernels. + + Must be called BEFORE PI052Policy is instantiated — Liger replaces + classes inside ``transformers.models.{gemma,gemma2,siglip,paligemma}``, + so any model built after the call picks up the fused forwards. + + ``spec`` is a comma-separated subset of {rope, rms_norm, geglu, + layer_norm} (also ``all`` and ``none``). ``cross_entropy`` and + ``fused_linear_cross_entropy`` are intentionally skipped — pi052's + losses use ``F.cross_entropy`` directly (not ``nn.CrossEntropyLoss``) + and never traverse ``PaliGemmaForConditionalGeneration.forward``, + so neither patch would fire without invasive model-code changes. + """ + enabled = dict.fromkeys(_LIGER_SUBKERNELS, False) + if spec in ("", "none"): + return enabled + tokens = [t.strip() for t in spec.split(",") if t.strip()] + if tokens == ["all"]: + enabled = dict.fromkeys(_LIGER_SUBKERNELS, True) + else: + for t in tokens: + if t not in enabled: + raise SystemExit(f"Unknown liger subkernel: {t!r}. Choose from {_LIGER_SUBKERNELS} or 'all'.") + enabled[t] = True + + from liger_kernel.transformers import apply_liger_kernel_to_paligemma + + apply_liger_kernel_to_paligemma( + rope=enabled["rope"], + rms_norm=enabled["rms_norm"], + geglu=enabled["geglu"], + layer_norm=enabled["layer_norm"], + cross_entropy=False, + fused_linear_cross_entropy=False, + ) + return enabled + + +def _maybe_patch_flex() -> None: + """Swap ``sdpa_attention_forward`` for a FlexAttention-backed forward. + + Experimental: builds a per-call ``score_mod`` from the additive + mask and dispatches to a compiled ``flex_attention`` kernel. + + Known issue on torch 2.7.1: dynamo errors out with + ``FlexAttentionHigherOrderVariable() has no type`` when the + ``score_mod`` closure captures a per-call bias tensor. A proper + port needs ``create_block_mask(mask_mod, ...)`` plumbed at the + PI05Pytorch.forward level so a BlockMask object can be passed + down to the layer compute, not a per-call closure. Left as + future work; keep this stub for benchmark experimentation.""" + import torch + from torch.nn.attention.flex_attention import flex_attention + + from lerobot.policies.pi05 import modeling_pi05 + + compiled_flex = torch.compile(flex_attention, dynamic=True) + + def flex_forward(module, query, key, value, attention_mask, scaling, dropout=0.0): + n_rep = module.num_key_value_groups + if n_rep > 1: + key = key.repeat_interleave(n_rep, dim=1) + value = value.repeat_interleave(n_rep, dim=1) + + bias = attention_mask # (B, 1, Lq, Lk) additive + + def score_mod(score, b, h, q_idx, kv_idx): + return score + bias[b, 0, q_idx, kv_idx] + + attn_output = compiled_flex(query, key, value, score_mod=score_mod, scale=scaling) + return attn_output.transpose(1, 2).contiguous(), None + + modeling_pi05.sdpa_attention_forward = flex_forward + + +def _build_policy(args, device: torch.device): + """Random-init PI052Policy at production-relevant shapes.""" + from lerobot.configs.types import FeatureType, PolicyFeature + from lerobot.policies.pi052.configuration_pi052 import PI052Config + from lerobot.policies.pi052.modeling_pi052 import PI052Policy + + # Production has ``unfreeze_lm_head=True`` + ``text_loss_weight>0``, + # which flips ``train_expert_only=False`` in __post_init__ and + # makes the whole PaliGemma + Gemma-expert stack trainable. We + # mirror that here so the optimizer-state count reflects reality; + # the loss path still goes through ``PI05Policy.forward`` because + # ``text_labels`` / FAST tokens are absent from the synthetic batch + # (see ``PI052Policy.forward`` early-return). + config = PI052Config( + max_action_dim=args.action_dim, + max_state_dim=args.state_dim, + dtype=args.dtype, + knowledge_insulation=args.knowledge_insulation, + text_loss_weight=1e-3 if args.train_full else 0.0, + flow_loss_weight=1.0, + enable_fast_action_loss=False, + unfreeze_lm_head=args.train_full, + tokenizer_max_length=args.lang_tokens, + device="cuda", + compile_model=args.compile_model, + compile_mode=args.compile_mode, + ) + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(args.state_dim,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(args.action_dim,)), + } + policy = PI052Policy(config) + policy.to(device) + if args.gradient_checkpointing: + policy.model.gradient_checkpointing_enable() + policy.train() + return policy, config + + +def _build_batch(args, config, device: torch.device) -> dict: + """Synthetic batch matching the training-loop input contract.""" + from lerobot.utils.constants import ( + ACTION, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + ) + + B = args.batch_size + L = args.lang_tokens + return { + OBS_LANGUAGE_TOKENS: torch.randint(0, 250000, (B, L), device=device), + OBS_LANGUAGE_ATTENTION_MASK: torch.ones(B, L, dtype=torch.bool, device=device), + "observation.images.base_0_rgb": torch.rand(B, 3, 224, 224, device=device), + "observation.images.base_0_rgb_padding_mask": torch.ones(B, dtype=torch.bool, device=device), + "observation.state": torch.randn(B, args.state_dim, device=device), + ACTION: torch.randn(B, config.chunk_size, args.action_dim, device=device), + "action_is_pad": torch.zeros(B, config.chunk_size, dtype=torch.bool, device=device), + "task": ["bench task"] * B, + } + + +def _step(policy, batch, optimizer=None) -> torch.Tensor: + loss, _ = policy.forward(batch) + loss.backward() + if optimizer is not None: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + else: + for p in policy.parameters(): + if p.grad is not None: + p.grad = None + return loss.detach() + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--attn", choices=["sdpa", "eager", "flex"], default="sdpa") + parser.add_argument( + "--kernels", + default="none", + help=( + "Liger sub-kernels to enable, comma-separated. Choose from " + f"{_LIGER_SUBKERNELS} or use 'all' / 'none' (default). Applied " + "via apply_liger_kernel_to_paligemma() BEFORE model build." + ), + ) + parser.add_argument( + "--compile", + dest="compile_model", + action="store_true", + help="Set policy.config.compile_model=True (torch.compile the forward).", + ) + parser.add_argument( + "--compile-mode", + default="default", + help="torch.compile mode (default | reduce-overhead | max-autotune).", + ) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--warmup", type=int, default=8) + parser.add_argument("--steps", type=int, default=40) + parser.add_argument("--lang-tokens", type=int, default=512) + parser.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16") + parser.add_argument("--action-dim", type=int, default=14) + parser.add_argument("--state-dim", type=int, default=14) + parser.add_argument("--knowledge-insulation", action="store_true", default=True) + parser.add_argument( + "--gradient-checkpointing", + dest="gradient_checkpointing", + action=argparse.BooleanOptionalAction, + default=True, + ) + parser.add_argument( + "--optimizer", + choices=["none", "adamw", "adamw_fused"], + default="adamw_fused", + help=( + "Whether to include an AdamW step in the timed iteration. " + "'none' mirrors the fwd+bwd-only original bench; 'adamw' / " + "'adamw_fused' add the realistic ~2x param-bytes optimizer " + "state and ``optimizer.step()`` cost." + ), + ) + parser.add_argument( + "--train-full", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Mirror production: unfreeze the PaliGemma backbone (full " + "~3B trainable params) instead of training only the 300M " + "action expert." + ), + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise SystemExit("Benchmark requires CUDA; submit via slurm (srun/sbatch).") + + if args.attn == "eager": + _maybe_patch_eager() + elif args.attn == "flex": + _maybe_patch_flex() + + liger_flags = _maybe_patch_liger(args.kernels) + + device = torch.device("cuda") + torch.cuda.reset_peak_memory_stats() + + policy, config = _build_policy(args, device) + batch = _build_batch(args, config, device) + + optimizer = None + trainable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + if args.optimizer != "none": + trainable = [p for p in policy.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW( + trainable, lr=5e-5, fused=(args.optimizer == "adamw_fused") + ) + + for _ in range(args.warmup): + _step(policy, batch, optimizer) + torch.cuda.synchronize() + + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + starter.record() + for _ in range(args.steps): + _step(policy, batch, optimizer) + ender.record() + torch.cuda.synchronize() + total_ms = starter.elapsed_time(ender) + step_ms = total_ms / args.steps + peak_gb = torch.cuda.max_memory_allocated() / (1024**3) + optim_gb = 0.0 + if optimizer is not None: + for st in optimizer.state.values(): + for v in st.values(): + if torch.is_tensor(v): + optim_gb += v.numel() * v.element_size() / (1024**3) + + liger_on = ",".join(k for k, v in liger_flags.items() if v) or "none" + name = ( + f"{args.attn:>5} | BS={args.batch_size} | L={args.lang_tokens} | " + f"KI={args.knowledge_insulation} | GC={args.gradient_checkpointing} | " + f"compile={args.compile_model} | liger={liger_on} | opt={args.optimizer} | dtype={args.dtype}" + ) + print( + f"{name}\n step_ms={step_ms:.1f} steps/sec={1000.0 / step_ms:.3f} " + f"peak_mem={peak_gb:.2f} GiB optim_state={optim_gb:.2f} GiB " + f"trainable_params={trainable_params / 1e9:.2f}B" + ) + + del policy, batch + gc.collect() + torch.cuda.empty_cache() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/benchmark/bench_pi052_step.slurm b/examples/benchmark/bench_pi052_step.slurm new file mode 100644 index 000000000..85b3b6063 --- /dev/null +++ b/examples/benchmark/bench_pi052_step.slurm @@ -0,0 +1,36 @@ +#!/bin/bash +#SBATCH --job-name=bench-pi052-attn +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=00:30:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --output=/fsx/pepijn/logs/bench_pi052_%j.out + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +echo "=== Node: $(hostname) ===" +nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader + +python -c "import torch; print('torch', torch.__version__, 'cuda', torch.version.cuda)" + +run() { + echo + echo "--- $* ---" + python examples/benchmark/bench_pi052_step.py "$@" || true +} + +# Attention parity benchmark — same shapes, different attention kernel. +run --attn eager --batch-size 8 +run --attn sdpa --batch-size 8 + +# Headroom benchmark — does SDPA's memory cut allow a bigger micro-batch? +run --attn sdpa --batch-size 12 +run --attn sdpa --batch-size 16 +run --attn sdpa --batch-size 24 diff --git a/examples/benchmark/bench_pi052_step_v2.slurm b/examples/benchmark/bench_pi052_step_v2.slurm new file mode 100644 index 000000000..839286bd5 --- /dev/null +++ b/examples/benchmark/bench_pi052_step_v2.slurm @@ -0,0 +1,39 @@ +#!/bin/bash +#SBATCH --job-name=bench-pi052-v2 +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=00:45:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v2_%j.out + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +echo "=== Node: $(hostname) ===" +nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader + +run() { + echo + echo "--- $* ---" + python examples/benchmark/bench_pi052_step.py "$@" || true +} + +# A: GC ON — see if the selective-AC change (one less recompute level) +# narrows the eager vs SDPA gap at BS=8. +run --attn eager --batch-size 8 +run --attn sdpa --batch-size 8 + +# B: GC OFF — isolate the raw attention-kernel cost & memory delta. +run --attn eager --batch-size 4 --no-gradient-checkpointing +run --attn sdpa --batch-size 4 --no-gradient-checkpointing + +# C: SDPA + GC headroom sweep — where does it OOM? +run --attn sdpa --batch-size 16 +run --attn sdpa --batch-size 24 +run --attn sdpa --batch-size 32 diff --git a/examples/benchmark/bench_pi052_step_v3.slurm b/examples/benchmark/bench_pi052_step_v3.slurm new file mode 100644 index 000000000..2cd426a05 --- /dev/null +++ b/examples/benchmark/bench_pi052_step_v3.slurm @@ -0,0 +1,36 @@ +#!/bin/bash +#SBATCH --job-name=bench-pi052-v3 +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=00:45:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v3_%j.out + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +echo "=== Node: $(hostname) ===" +nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader + +run() { + echo + echo "--- $* ---" + python examples/benchmark/bench_pi052_step.py "$@" || true +} + +# Compile sweep: does torch.compile + SDPA give a non-trivial boost on +# top of the bare SDPA path? +run --attn sdpa --batch-size 8 --compile +run --attn sdpa --batch-size 16 --compile + +# FlexAttention sweep (experimental): score_mod adds the additive bias +# in-kernel; expect a long first-step compile, then SDPA-or-better steady +# state. +run --attn flex --batch-size 8 +run --attn flex --batch-size 16 diff --git a/examples/benchmark/bench_pi052_step_v4.slurm b/examples/benchmark/bench_pi052_step_v4.slurm new file mode 100644 index 000000000..f4b88dfa9 --- /dev/null +++ b/examples/benchmark/bench_pi052_step_v4.slurm @@ -0,0 +1,41 @@ +#!/bin/bash +#SBATCH --job-name=bench-pi052-v4 +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=01:00:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v4_%j.out + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# /fsx triton cache is shared across nodes with different glibc versions +# — kernels built on one node trip GLIBC_2.34-not-found on another. Use +# a node-local cache per job to side-step that. +export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}" +export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}" +mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR" + +echo "=== Node: $(hostname) ===" +nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader +ldd --version | head -1 + +run() { + echo + echo "--- $* ---" + python examples/benchmark/bench_pi052_step.py "$@" || true +} + +# compile path on top of SDPA + selective AC +run --attn sdpa --batch-size 8 --compile +run --attn sdpa --batch-size 16 --compile + +# FlexAttention experimental +run --attn flex --batch-size 8 +run --attn flex --batch-size 16 diff --git a/examples/benchmark/bench_pi052_step_v5.slurm b/examples/benchmark/bench_pi052_step_v5.slurm new file mode 100644 index 000000000..3a9fc102a --- /dev/null +++ b/examples/benchmark/bench_pi052_step_v5.slurm @@ -0,0 +1,33 @@ +#!/bin/bash +#SBATCH --job-name=bench-pi052-v5 +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=00:45:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v5_%j.out + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}" +export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}" +mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR" + +echo "=== Node: $(hostname) ===" + +run() { + echo + echo "--- $* ---" + python examples/benchmark/bench_pi052_step.py "$@" || true +} + +# compile_mode=default (graph-only, no autotune) is the right knob with +# gradient checkpointing — max-autotune in v4 was 2x slower than no-compile. +run --attn sdpa --batch-size 8 --compile --compile-mode default +run --attn sdpa --batch-size 16 --compile --compile-mode default +run --attn sdpa --batch-size 8 --compile --compile-mode reduce-overhead diff --git a/examples/benchmark/bench_pi052_step_v6.slurm b/examples/benchmark/bench_pi052_step_v6.slurm new file mode 100644 index 000000000..99e016811 --- /dev/null +++ b/examples/benchmark/bench_pi052_step_v6.slurm @@ -0,0 +1,31 @@ +#!/bin/bash +#SBATCH --job-name=bench-pi052-v6-bs32 +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=00:30:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v6_%j.out + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}" +export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}" +mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR" + +echo "=== Node: $(hostname) ===" +nvidia-smi --query-gpu=name,memory.total --format=csv,noheader + +run() { + echo + echo "--- $* ---" + python examples/benchmark/bench_pi052_step.py "$@" || true +} + +# BS=32 with the production settings (SDPA + compile=default). +run --attn sdpa --batch-size 32 --compile --compile-mode default diff --git a/examples/benchmark/bench_pi052_step_v7.slurm b/examples/benchmark/bench_pi052_step_v7.slurm new file mode 100644 index 000000000..6afc528af --- /dev/null +++ b/examples/benchmark/bench_pi052_step_v7.slurm @@ -0,0 +1,39 @@ +#!/bin/bash +#SBATCH --job-name=bench-pi052-v7-opt +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=00:45:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v7_%j.out + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}" +export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}" +mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR" + +echo "=== Node: $(hostname) ===" +nvidia-smi --query-gpu=name,memory.total --format=csv,noheader + +run() { + echo + echo "--- $* ---" + python examples/benchmark/bench_pi052_step.py "$@" || true +} + +# Realistic full-step memory: fwd + bwd + AdamW step. The original +# sweep was fwd+bwd-only and undercounted memory by the optimizer- +# state size (~2x param bytes for AdamW). This run confirms BS=16 +# and BS=32 still fit with the optimizer in residency. +run --attn sdpa --batch-size 16 --compile --compile-mode default --optimizer adamw_fused +run --attn sdpa --batch-size 32 --compile --compile-mode default --optimizer adamw_fused + +# Without compile, in case the production cluster has compile issues. +run --attn sdpa --batch-size 16 --optimizer adamw_fused +run --attn sdpa --batch-size 32 --optimizer adamw_fused diff --git a/examples/benchmark/bench_pi052_step_v8.slurm b/examples/benchmark/bench_pi052_step_v8.slurm new file mode 100644 index 000000000..a8ed8a8aa --- /dev/null +++ b/examples/benchmark/bench_pi052_step_v8.slurm @@ -0,0 +1,36 @@ +#!/bin/bash +#SBATCH --job-name=bench-pi052-v8-bs40-dtype +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=00:45:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=1 +#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v8_%j.out + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}" +export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}" +mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR" + +echo "=== Node: $(hostname) ===" +nvidia-smi --query-gpu=name,memory.total --format=csv,noheader + +run() { + echo + echo "--- $* ---" + python examples/benchmark/bench_pi052_step.py "$@" || true +} + +# Confirm BS=40 fits on a single H100 with the optimizer in residency. +run --attn sdpa --batch-size 40 --compile --compile-mode default --optimizer adamw_fused + +# Dtype A/B at modest batch — fp32 needs ~2x the memory of bf16, so we +# drop to BS=4 to keep both runs comparable instead of OOMing fp32. +run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype bfloat16 +run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype float32 diff --git a/examples/benchmark/fsdp_pi052.yaml b/examples/benchmark/fsdp_pi052.yaml new file mode 100644 index 000000000..f9f8b71da --- /dev/null +++ b/examples/benchmark/fsdp_pi052.yaml @@ -0,0 +1,29 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_transformer_layer_cls_to_wrap: GemmaDecoderLayer,SiglipEncoderLayer + fsdp_use_orig_params: true + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index c9110301f..6987f7f37 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -178,7 +178,6 @@ N_COLOR_CHANNELS = 3 # config -@strict class GR00TN15Config(PretrainedConfig): model_type = "gr00t_n1_5" diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 265d464c9..dc5a26ed0 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -15,6 +15,7 @@ # limitations under the License. import builtins +import copy import logging import math from collections import deque @@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: - from transformers.cache_utils import DynamicCache from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma @@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available: ) else: CONFIG_MAPPING = None - DynamicCache = None modeling_gemma = None PiGemmaForCausalLM = None _gated_residual = None @@ -139,15 +138,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` ( return att_2d_masks & pad_2d_masks -def clone_past_key_values(past_key_values): - """Clone the DynamicCache returned by prefix prefill for compiled denoising.""" - return DynamicCache( - tuple( - (keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values - ) - ) - - def pad_vector(vector, new_dim): """Pad the last dimension of a vector to new_dim with zeros. @@ -233,14 +223,53 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) return padded_images +def sdpa_attention_forward( + module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, +): + """Drop-in for ``modeling_gemma.eager_attention_forward`` using + ``torch.nn.functional.scaled_dot_product_attention``. + + PyTorch SDPA picks the memory-efficient kernel for arbitrary additive + bias masks (the FA backend only accepts causal/sliding-window). On + H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory + than the eager softmax(QK^T)+matmul path. Mirrors eager's signature + and output shape (``(B, Lq, H, D)``) so call sites are unchanged. + """ + n_rep = module.num_key_value_groups + if n_rep > 1: + key = key.repeat_interleave(n_rep, dim=1) + value = value.repeat_interleave(n_rep, dim=1) + if attention_mask is not None and attention_mask.dtype != query.dtype: + attention_mask = attention_mask.to(dtype=query.dtype) + attn_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=dropout if module.training else 0.0, + is_causal=False, + scale=scaling, + ) + return attn_output.transpose(1, 2).contiguous(), None + + # Define the complete layer computation function for gradient checkpointing -def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb): +def compute_layer_complete( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert +): + models = [paligemma.model.language_model, gemma_expert.model] query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): - layer = layers[i] + layer = models[i].layers[layer_idx] hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i]) gates.append(gate) input_shape = hidden_states.shape[:-1] @@ -262,16 +291,14 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c device=query_states.device, dtype=query_states.dtype, ) - cos, sin = rotary_emb(dummy_tensor, position_ids) + cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) query_states, key_states = modeling_gemma.apply_rotary_pos_emb( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] - paligemma_layer = layers[0] - scaling = paligemma_layer.self_attn.scaling - # Attention computation - att_output, _ = modeling_gemma.eager_attention_forward( - paligemma_layer.self_attn, + scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling + att_output, _ = sdpa_attention_forward( + paligemma.model.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, @@ -279,13 +306,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c scaling, ) # Get head_dim from the current layer, not from the model - head_dim = paligemma_layer.self_attn.head_dim + head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): - layer = layers[i] + layer = models[i].layers[layer_idx] end_pos = start_pos + hidden_states.shape[1] if att_output.dtype != layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) @@ -450,13 +477,13 @@ class PaliGemmaWithExpertModel( if image.dtype != torch.float32: image = image.to(torch.float32) image_outputs = self.paligemma.model.get_image_features(image) - features = image_outputs.pooler_output + features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 if features.dtype != out_dtype: features = features.to(out_dtype) return features def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.model.language_model.get_input_embeddings()(tokens) + return self.paligemma.model.language_model.embed_tokens(tokens) def forward( self, @@ -494,9 +521,8 @@ class PaliGemmaWithExpertModel( prefix_output = None prefix_past_key_values = None else: - paligemma_layers = self.paligemma.model.language_model.layers - gemma_expert_layers = self.gemma_expert.model.layers - rotary_emb = self.paligemma.model.language_model.rotary_emb + models = [self.paligemma.model.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers # Check if gradient checkpointing is enabled for any of the models use_gradient_checkpointing = ( @@ -506,39 +532,36 @@ class PaliGemmaWithExpertModel( ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) # Process all layers with gradient checkpointing if enabled - for layers in zip(paligemma_layers, gemma_expert_layers, strict=True): + for layer_idx in range(num_layers): if use_gradient_checkpointing: inputs_embeds = torch.utils.checkpoint.checkpoint( compute_layer_complete, + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, use_reentrant=False, preserve_rng_state=False, - layers=layers, - rotary_emb=rotary_emb, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, ) else: inputs_embeds = compute_layer_complete( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, - layers=layers, - rotary_emb=rotary_emb, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, ) # final norm - final_norms = ( - self.paligemma.model.language_model.norm, - self.gemma_expert.model.norm, - ) - def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i]) + out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds @@ -678,7 +701,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Process language tokens def lang_embed_func(tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) - return lang_emb + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) lang_emb = self._apply_checkpoint(lang_embed_func, tokens) embs.append(lang_emb) @@ -767,19 +791,20 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype) - def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): - (_, suffix_out), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks_4d, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - adarms_cond=[None, adarms_cond], - ) - return suffix_out - - suffix_out = self._apply_checkpoint( - forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond + # Selective AC: rely on the per-layer checkpoint inside + # ``PaliGemmaWithExpertModel.forward`` (which wraps each + # transformer block individually). The previous outer + # ``_apply_checkpoint(forward_func, ...)`` doubled up — it + # re-ran the full backbone forward during backward *and* each + # block's own checkpoint re-ran during that recompute. Pure + # waste with SDPA, which already streams attention activations. + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], ) suffix_out = suffix_out[:, -self.config.chunk_size :] @@ -900,7 +925,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 - past_key_values = clone_past_key_values(past_key_values) + past_key_values = copy.deepcopy(past_key_values) outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 9b0c66a4c..709e7724f 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -291,11 +291,13 @@ def _compute_layer_ki( if mask_for_action.dtype != Q_action.dtype: mask_for_action = mask_for_action.to(dtype=Q_action.dtype) - att_vlm, _ = modeling_gemma.eager_attention_forward( + from ..pi05.modeling_pi05 import sdpa_attention_forward # noqa: PLC0415 + + att_vlm, _ = sdpa_attention_forward( paligemma.model.language_model.layers[layer_idx].self_attn, Q_vlm, K_for_vlm, V_for_vlm, mask_for_vlm, scaling, ) - att_action, _ = modeling_gemma.eager_attention_forward( + att_action, _ = sdpa_attention_forward( paligemma.model.language_model.layers[layer_idx].self_attn, Q_action, K_for_action, V_for_action, mask_for_action, scaling, ) diff --git a/tests/policies/pi052/test_pi052_sdpa_attention.py b/tests/policies/pi052/test_pi052_sdpa_attention.py new file mode 100644 index 000000000..808e80faf --- /dev/null +++ b/tests/policies/pi052/test_pi052_sdpa_attention.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Numerical-parity tests for the SDPA attention port. + +``pi05`` / ``pi052`` replaced the per-layer call from +``modeling_gemma.eager_attention_forward`` with +``sdpa_attention_forward`` (PyTorch SDPA + GQA repeat). The forward +output must be bit-equivalent (within bf16 tolerance) on the masks +this model actually uses — block-bidirectional with an arbitrary +additive bias — otherwise we silently change training behaviour. +""" + +from types import SimpleNamespace + +import pytest +import torch + +pytest.importorskip("transformers") + +from transformers.models.gemma import modeling_gemma # noqa: E402 + +from lerobot.policies.pi05.modeling_pi05 import ( # noqa: E402 + make_att_2d_masks, + sdpa_attention_forward, +) +from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE # noqa: E402 + + +def _mock_self_attn(num_kv_groups: int, training: bool = False): + """Bare module surface that both forwards read.""" + return SimpleNamespace( + num_key_value_groups=num_kv_groups, + training=training, + ) + + +def _build_inputs( + bsize: int, + num_heads: int, + num_kv_heads: int, + seq_len: int, + head_dim: int, + dtype: torch.dtype, + seed: int = 0, +): + g = torch.Generator(device="cpu").manual_seed(seed) + q = torch.randn(bsize, num_heads, seq_len, head_dim, dtype=dtype, generator=g) + k = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g) + v = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g) + return q, k, v + + +def _block_bidirectional_mask( + bsize: int, seq_len: int, block_sizes: list[int], dtype: torch.dtype +) -> torch.Tensor: + """Mimic ``_prepare_attention_masks_4d`` on a block layout that + matches ``[images, language, suffix]`` from ``embed_prefix`` + + ``embed_suffix``: every block bidirectional internally, later + blocks visible to earlier ones via the cumulative-block rule. + """ + assert sum(block_sizes) == seq_len + att_marks = [] + for i, n in enumerate(block_sizes): + att_marks += [1 if i > 0 else 0] + [0] * (n - 1) + pad = torch.ones(bsize, seq_len, dtype=torch.bool) + att = torch.tensor(att_marks, dtype=torch.bool)[None].expand(bsize, seq_len) + att_2d = make_att_2d_masks(pad, att) + bias = torch.where( + att_2d[:, None, :, :], + torch.zeros((), dtype=dtype), + torch.tensor(OPENPI_ATTENTION_MASK_VALUE, dtype=dtype), + ) + return bias + + +@pytest.mark.parametrize( + "num_heads,num_kv_heads,head_dim", + [ + (8, 1, 256), # gemma_2b / paligemma config + (8, 8, 64), # MHA control (no GQA repeat) + ], +) +def test_sdpa_parity_with_eager_block_bidirectional(num_heads, num_kv_heads, head_dim): + """SDPA forward output matches the eager softmax(QK^T)@V on the + block-bidirectional mask layout pi05 actually uses.""" + bsize, seq_len = 2, 13 + block_sizes = [4, 5, 4] # images, language, suffix-style blocks + dtype = torch.float32 # cpu math kernel — keep fp32 for tight tol + scaling = head_dim ** -0.5 + + q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, dtype) + mask = _block_bidirectional_mask(bsize, seq_len, block_sizes, dtype) + + module = _mock_self_attn(num_heads // num_kv_heads) + + out_eager, _ = modeling_gemma.eager_attention_forward( + module, q, k, v, mask, scaling + ) + out_sdpa, _ = sdpa_attention_forward( + module, q, k, v, mask, scaling + ) + assert out_eager.shape == out_sdpa.shape + torch.testing.assert_close(out_sdpa, out_eager, atol=1e-5, rtol=1e-4) + + +def test_sdpa_parity_bf16(): + """bf16 path — looser tolerance, must still match eager.""" + bsize, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 1, 17, 256 + scaling = head_dim ** -0.5 + q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.bfloat16) + mask = _block_bidirectional_mask(bsize, seq_len, [5, 6, 6], torch.bfloat16) + module = _mock_self_attn(num_heads // num_kv_heads) + + out_eager, _ = modeling_gemma.eager_attention_forward( + module, q, k, v, mask, scaling + ) + out_sdpa, _ = sdpa_attention_forward( + module, q, k, v, mask, scaling + ) + torch.testing.assert_close(out_sdpa, out_eager, atol=2e-2, rtol=2e-2) + + +def test_sdpa_parity_backward(): + """Gradients flow through SDPA and match the eager path within + bf16 tolerance — critical for any training-side parity claim.""" + bsize, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 9, 32 + scaling = head_dim ** -0.5 + q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.float32) + q.requires_grad_(True); k.requires_grad_(True); v.requires_grad_(True) + mask = _block_bidirectional_mask(bsize, seq_len, [3, 3, 3], torch.float32) + module = _mock_self_attn(num_heads // num_kv_heads) + + out_e, _ = modeling_gemma.eager_attention_forward(module, q, k, v, mask, scaling) + g_q_e, g_k_e, g_v_e = torch.autograd.grad(out_e.sum(), [q, k, v]) + + out_s, _ = sdpa_attention_forward(module, q, k, v, mask, scaling) + g_q_s, g_k_s, g_v_s = torch.autograd.grad(out_s.sum(), [q, k, v]) + + torch.testing.assert_close(g_q_s, g_q_e, atol=1e-5, rtol=1e-4) + torch.testing.assert_close(g_k_s, g_k_e, atol=1e-5, rtol=1e-4) + torch.testing.assert_close(g_v_s, g_v_e, atol=1e-5, rtol=1e-4)