pi052: drop `use_hf_kernels` flag — always patch Liger kernels

The flag gated a process-global, idempotent Liger patch that swaps
in fused Triton rope / geglu / layer_norm kernels (~4.5 % step time
on H100, bench job 22161421). Since liger-kernel is now a hard
dependency of the loss path (``_shifted_lin_ce`` / ``_fast_lin_ce``
in ``modeling_pi052``), gating the same dep behind an opt-in flag
was redundant — every pi052 run pulls the wheel in either way.

* ``PI052Policy.__init__`` calls ``_enable_hf_kernels()``
  unconditionally; the function still degrades gracefully if the
  wheel happens to be missing (logs a warning, returns).
* Drop ``PI052Config.use_hf_kernels``; the bench numbers and the
  ``fused_linear_cross_entropy`` pointer to ``_shifted_lin_ce`` /
  ``_fast_lin_ce`` are kept as comments next to the docstring.
* Update the warning + ``_shifted_lin_ce`` lazy-import comment to
  drop stale ``use_hf_kernels`` / ``reduce-overhead`` references.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-26 11:47:49 +00:00
parent 4c3ddb1ff5
commit d70c810416
2 changed files with 17 additions and 33 deletions
@@ -190,30 +190,13 @@ class PI052Config(PI05Config):
# commonly cited weight; set 0 to disable entirely. # commonly cited weight; set 0 to disable entirely.
text_ce_z_loss_weight: float = 1e-4 text_ce_z_loss_weight: float = 1e-4
# Fused kernels (Liger via HF kernels lib) --------------------------- # Liger Triton kernels (rope + geglu + layer_norm) are now patched
# Patches PaliGemma / Gemma / Siglip ops with Liger Triton kernels # unconditionally at model build time — see ``_enable_hf_kernels``
# before the model is built. Measured on H100 80GB at BS=16 / L=512 # in ``modeling_pi052``. The patch is process-global, idempotent
# with KI+GC on (bench job 22161421, see # and degrades gracefully if ``liger-kernel`` is missing. Measured
# ``examples/benchmark/bench_pi052_kernels.slurm``): # at -4.5% step time on H100 (bench job 22161421); peak memory
# # unchanged. ``fused_linear_cross_entropy`` ships separately via
# rope only → 2.5% step time # ``_shifted_lin_ce`` / ``_fast_lin_ce``.
# geglu only → 2.2% step time
# layer_norm only → 1.1% step time
# all three → 4.5% step time, peak_mem unchanged
#
# ``fused_linear_cross_entropy`` is now wired directly into the
# pi052 forward via ``_shifted_lin_ce`` / ``_fast_lin_ce`` (see
# ``modeling_pi052``). The kernel takes ``(hidden_states,
# lm_head.weight, labels)`` and computes matmul + softmax + CE in
# fused Triton blocks, never materialising the (B, T, 257k) logits
# tensor. Saves ~10 GB activation memory per CE branch and ~30 %
# step time on the dual-CE pi052 recipe (text + FAST). Removing the
# ``.any().item()`` sync also lets ``compile_mode=reduce-overhead``
# capture full CUDA graphs over the loss path.
use_hf_kernels: bool = False
"""If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's
fused Triton kernels (rope + geglu + layer_norm). Off by default;
requires ``pip install liger-kernel``."""
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__post_init__() super().__post_init__()
+10 -9
View File
@@ -77,8 +77,9 @@ def _enable_hf_kernels() -> None:
from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415 from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415
except ImportError: except ImportError:
logger.warning( logger.warning(
"PI052: use_hf_kernels=True but liger-kernel is not installed; " "PI052: liger-kernel is not installed; skipping fused Triton "
"skipping. Install with `pip install liger-kernel`." "kernels (rope/geglu/layer_norm). Install with "
"``pip install liger-kernel`` for a ~4.5%% step speedup."
) )
return return
apply_liger_kernel_to_paligemma( apply_liger_kernel_to_paligemma(
@@ -126,15 +127,14 @@ def _shifted_lin_ce(
* Shift convention identical to the eager version — hidden at * Shift convention identical to the eager version — hidden at
position ``t`` predicts label at ``t+1``; ``ignore_index=-100``. position ``t`` predicts label at ``t+1``; ``ignore_index=-100``.
* No ``.any().item()`` sync — Liger returns 0.0 cleanly when * No ``.any().item()`` sync — Liger returns 0.0 cleanly when
every label is ignored, keeping the graph capturable for every label is ignored.
``compile_mode=reduce-overhead`` (CUDA graphs).
* ``z_loss_weight`` maps directly to Liger's ``lse_square_scale`` * ``z_loss_weight`` maps directly to Liger's ``lse_square_scale``
(same ``z²·w`` formula on per-position logsumexp). Setting it (same ``z²·w`` formula on per-position logsumexp). Setting it
to 0 disables the z-loss term at zero cost. to 0 disables the z-loss term at zero cost.
""" """
# Liger is imported lazily so the module still imports on machines # Liger is imported lazily so the module still imports on machines
# without liger-kernel; the call site only ever runs after # without liger-kernel the call site only fires from the training
# use_hf_kernels / training has selected the Liger path. # forward, which always pulls in the kernel.
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415 from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
LigerFusedLinearCrossEntropyLoss, LigerFusedLinearCrossEntropyLoss,
) )
@@ -434,9 +434,10 @@ class PI052Policy(PI05Policy):
def __init__(self, config: PI052Config, **kwargs: Any) -> None: def __init__(self, config: PI052Config, **kwargs: Any) -> None:
# Patch ops BEFORE the backbone is built (super().__init__ below # Patch ops BEFORE the backbone is built (super().__init__ below
# constructs PaliGemmaWithExpertModel which instantiates the # constructs PaliGemmaWithExpertModel which instantiates the
# Gemma/Siglip layers we want to swap). # Gemma/Siglip layers we want to swap). Always-on — the patch
if getattr(config, "use_hf_kernels", False): # is process-global / idempotent and degrades gracefully if
_enable_hf_kernels() # liger-kernel is missing.
_enable_hf_kernels()
super().__init__(config, **kwargs) super().__init__(config, **kwargs)
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and # ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and