mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
4913356564
Replaces the per-layer ``modeling_gemma.eager_attention_forward`` call with ``torch.nn.functional.scaled_dot_product_attention`` in ``compute_layer_complete`` (pi05) and ``_compute_layer_ki`` (pi052). PyTorch SDPA picks the memory-efficient kernel for the block-bidirectional 4D additive mask the dual-expert model uses (FA2 / FA3 reject it because they only accept causal / sliding-window / varlen patterns). The shared ``sdpa_attention_forward`` helper mirrors the eager signature so the call sites are unchanged.

Selective AC: removes the redundant outer ``_apply_checkpoint(forward_func, ...)`` wrap in ``PI05Pytorch.forward``. Per-layer checkpointing inside ``PaliGemmaWithExpertModel.forward`` already handles activation recompute; the outer wrap was double-recomputing the whole backbone. +14% steps/sec on its own (job 22161405 vs 22161398, 1xH100).

groot: drop ``@strict`` on ``GR00TN15Config`` — newer ``huggingface_hub`` rejects ``@strict`` on non-dataclass ``PretrainedConfig`` subclasses, which was blocking imports of any sibling policy through ``lerobot.policies.factory``.

New ``examples/benchmark/bench_pi052_step.py`` (+ slurm sweeps v1..v8) times PI052Policy.forward+backward (optionally with AdamW) on synthetic inputs. Headline numbers on 1xH100 with KI=True, GC=True, L=512, 4.14 B trainable params, AdamW state in bf16:

  pre-SDPA eager BS=8                   610 ms  19.5 GiB -> 13.1 samples/s
  sdpa BS=8  + compile=default          413 ms  19.5 GiB -> 19.3 samples/s
  sdpa BS=16 + compile=default          715 ms  37.3 GiB -> 22.4 samples/s
  sdpa BS=32 + compile=default         1325 ms  44.8 GiB -> 24.2 samples/s
  sdpa BS=40 + compile=default         1665 ms  48.6 GiB -> 24.0 samples/s

Parity tests in ``tests/policies/pi052/test_pi052_sdpa_attention.py`` cover fp32 / bf16 / GQA / MHA forward + backward — output and grads match the eager path within bf16 tolerance.
Also ships ``examples/benchmark/fsdp_pi052.yaml`` (FSDP2 accelerate config wrapping GemmaDecoderLayer + SiglipEncoderLayer) for the follow-up multi-GPU memory sharding work. Co-authored-by: Cursor <cursoragent@cursor.com>
75 lines
2.9 KiB
Bash
75 lines
2.9 KiB
Bash
#!/bin/bash
#SBATCH --job-name=bench-pi052-kernels
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=01:30:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_kernels_%j.out

# HF kernels exploration via Liger's apply_liger_kernel_to_paligemma.
# Baseline (SDPA, no kernels) vs. per-subkernel ablations vs. all-on.
# Same harness as bench_pi052_step.py — only the --kernels flag varies
# across runs so any delta is attributable to the patched op(s).
#
# Subkernels exercised: rope, rms_norm, geglu, layer_norm.
# Skipped: cross_entropy / fused_linear_cross_entropy — pi052 calls
# F.cross_entropy directly and bypasses PaliGemma's forward, so those
# patches wouldn't fire without model-code changes (separate PR).
#
# Environment:
#   LEROBOT_ROOT  checkout to run from (default: $HOME/lerobot)
#   SLURM_JOB_ID  set by sbatch; tags the node-local compile caches below
#                 (falls back to the shell PID when run outside Slurm, so
#                 `set -u` does not abort an interactive debug run)

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

# /fsx triton cache is shared across nodes with different glibc versions
# — kernels built on one node trip GLIBC_2.34-not-found on another. Use
# a node-local cache per job to side-step that. mktemp gives every run a
# fresh, non-predictable directory, and the EXIT trap removes both caches so
# repeated jobs don't accumulate compiled kernels in the node's /tmp.
TRITON_CACHE_DIR=""
TORCHINDUCTOR_CACHE_DIR=""

# Remove whichever compile caches were actually created (empty = not yet).
cleanup_compile_caches() {
  local dir
  for dir in "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"; do
    if [[ -n "$dir" ]]; then
      rm -rf -- "$dir"
    fi
  done
}
# Installed before mktemp so a failure creating the second dir still
# cleans up the first.
trap cleanup_compile_caches EXIT

job_tag="${SLURM_JOB_ID:-local-$$}"
TRITON_CACHE_DIR="$(mktemp -d "/tmp/triton_${job_tag}.XXXXXX")"
TORCHINDUCTOR_CACHE_DIR="$(mktemp -d "/tmp/torchinductor_${job_tag}.XXXXXX")"
export TRITON_CACHE_DIR TORCHINDUCTOR_CACHE_DIR
# Record where the numbers came from: host, GPU name/driver/memory, glibc.
echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
# `sed -n 1p` reads the whole stream. `head -1` can exit after the first
# line, so ldd's later writes may hit SIGPIPE (status 141); with
# `set -o pipefail` that would abort the job before any benchmark runs.
ldd --version | sed -n 1p
# Liger isn't in our standard env yet — install on the compute node so
# the slurm log captures the exact version that produced the numbers.
python -m pip install -q --upgrade 'liger-kernel'

# Best-effort probe (never fails the job thanks to `|| true`): report the
# installed version and whether this Liger release exposes the PaliGemma
# patcher that the --kernels runs are exploring. The import is guarded so a
# broken install prints a one-line reason instead of a traceback.
python - <<'PY' || true
from importlib.metadata import version, PackageNotFoundError

try:
    print("liger-kernel", version("liger-kernel"))
except PackageNotFoundError:
    print("liger-kernel: not importable")

try:
    import liger_kernel.transformers as t
except ImportError as exc:
    print("liger_kernel.transformers import failed:", exc)
else:
    print("apply_liger_kernel_to_paligemma:", hasattr(t, "apply_liger_kernel_to_paligemma"))
PY
#######################################
# Run one benchmark configuration, best-effort.
# Arguments: forwarded verbatim to examples/benchmark/bench_pi052_step.py.
# Outputs:   a blank line and a "--- <args> ---" banner on stdout, then the
#            benchmark's own output; on failure, a "!!! FAILED" line with
#            the exit status on stderr so the log shows which config broke.
# Returns:   always 0, so one failing configuration (e.g. an OOM at a large
#            batch size) does not abort the rest of the sweep under `set -e`.
#######################################
run() {
  local rc=0
  echo
  echo "--- $* ---"
  python examples/benchmark/bench_pi052_step.py "$@" || rc=$?
  if (( rc != 0 )); then
    printf '!!! FAILED (exit status %d): %s\n' "$rc" "$*" >&2
  fi
  return 0
}
# Every configuration below uses the SDPA attention path; only the batch
# size and the Liger kernel selection change between runs. Order matters
# only for log readability: baselines first, then ablations, then all-on,
# then the headroom probe.

# Baselines with no kernels, at the batch sizes we actually train at.
for bs in 8 16; do
  run --attn sdpa --batch-size "$bs" --kernels none
done

# One subkernel at a time at BS=16, isolating each contributor.
for kernel in rms_norm geglu layer_norm rope; do
  run --attn sdpa --batch-size 16 --kernels "$kernel"
done

# Everything enabled, at both batch sizes, matching the baselines above.
for bs in 8 16; do
  run --attn sdpa --batch-size "$bs" --kernels all
done

# Headroom probe: the baseline OOMs around BS=24 — does kernels=all fit?
for kernel_set in none all; do
  run --attn sdpa --batch-size 24 --kernels "$kernel_set"
done