mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
pi052: drop `use_hf_kernels` flag — always patch Liger kernels
The flag gated a process-global, idempotent Liger patch that swaps in fused Triton rope / geglu / layer_norm kernels (~4.5 % step time on H100, bench job 22161421). Since liger-kernel is now a hard dependency of the loss path (``_shifted_lin_ce`` / ``_fast_lin_ce`` in ``modeling_pi052``), gating the same dep behind an opt-in flag was redundant — every pi052 run pulls the wheel in either way. * ``PI052Policy.__init__`` calls ``_enable_hf_kernels()`` unconditionally; the function still degrades gracefully if the wheel happens to be missing (logs a warning, returns). * Drop ``PI052Config.use_hf_kernels``; the bench numbers and the ``fused_linear_cross_entropy`` pointer to ``_shifted_lin_ce`` / ``_fast_lin_ce`` are kept as comments next to the docstring. * Update the warning + ``_shifted_lin_ce`` lazy-import comment to drop stale ``use_hf_kernels`` / ``reduce-overhead`` references. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -190,30 +190,13 @@ class PI052Config(PI05Config):
|
||||
# commonly cited weight; set 0 to disable entirely.
|
||||
text_ce_z_loss_weight: float = 1e-4
|
||||
|
||||
# Fused kernels (Liger via HF kernels lib) ---------------------------
|
||||
# Patches PaliGemma / Gemma / Siglip ops with Liger Triton kernels
|
||||
# before the model is built. Measured on H100 80GB at BS=16 / L=512
|
||||
# with KI+GC on (bench job 22161421, see
|
||||
# ``examples/benchmark/bench_pi052_kernels.slurm``):
|
||||
#
|
||||
# rope only → −2.5% step time
|
||||
# geglu only → −2.2% step time
|
||||
# layer_norm only → −1.1% step time
|
||||
# all three → −4.5% step time, peak_mem unchanged
|
||||
#
|
||||
# ``fused_linear_cross_entropy`` is now wired directly into the
|
||||
# pi052 forward via ``_shifted_lin_ce`` / ``_fast_lin_ce`` (see
|
||||
# ``modeling_pi052``). The kernel takes ``(hidden_states,
|
||||
# lm_head.weight, labels)`` and computes matmul + softmax + CE in
|
||||
# fused Triton blocks, never materialising the (B, T, 257k) logits
|
||||
# tensor. Saves ~10 GB activation memory per CE branch and ~30 %
|
||||
# step time on the dual-CE pi052 recipe (text + FAST). Removing the
|
||||
# ``.any().item()`` sync also lets ``compile_mode=reduce-overhead``
|
||||
# capture full CUDA graphs over the loss path.
|
||||
use_hf_kernels: bool = False
|
||||
"""If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's
|
||||
fused Triton kernels (rope + geglu + layer_norm). Off by default;
|
||||
requires ``pip install liger-kernel``."""
|
||||
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
|
||||
# unconditionally at model build time — see ``_enable_hf_kernels``
|
||||
# in ``modeling_pi052``. The patch is process-global, idempotent
|
||||
# and degrades gracefully if ``liger-kernel`` is missing. Measured
|
||||
# at -4.5% step time on H100 (bench job 22161421); peak memory
|
||||
# unchanged. ``fused_linear_cross_entropy`` ships separately via
|
||||
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
@@ -77,8 +77,9 @@ def _enable_hf_kernels() -> None:
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"PI052: use_hf_kernels=True but liger-kernel is not installed; "
|
||||
"skipping. Install with `pip install liger-kernel`."
|
||||
"PI052: liger-kernel is not installed; skipping fused Triton "
|
||||
"kernels (rope/geglu/layer_norm). Install with "
|
||||
"``pip install liger-kernel`` for a ~4.5%% step speedup."
|
||||
)
|
||||
return
|
||||
apply_liger_kernel_to_paligemma(
|
||||
@@ -126,15 +127,14 @@ def _shifted_lin_ce(
|
||||
* Shift convention identical to the eager version — hidden at
|
||||
position ``t`` predicts label at ``t+1``; ``ignore_index=-100``.
|
||||
* No ``.any().item()`` sync — Liger returns 0.0 cleanly when
|
||||
every label is ignored, keeping the graph capturable for
|
||||
``compile_mode=reduce-overhead`` (CUDA graphs).
|
||||
every label is ignored.
|
||||
* ``z_loss_weight`` maps directly to Liger's ``lse_square_scale``
|
||||
(same ``z²·w`` formula on per-position logsumexp). Setting it
|
||||
to 0 disables the z-loss term at zero cost.
|
||||
"""
|
||||
# Liger is imported lazily so the module still imports on machines
|
||||
# without liger-kernel; the call site only ever runs after
|
||||
# use_hf_kernels / training has selected the Liger path.
|
||||
# without liger-kernel — the call site only fires from the training
|
||||
# forward, which always pulls in the kernel.
|
||||
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
|
||||
LigerFusedLinearCrossEntropyLoss,
|
||||
)
|
||||
@@ -434,9 +434,10 @@ class PI052Policy(PI05Policy):
|
||||
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
|
||||
# Patch ops BEFORE the backbone is built (super().__init__ below
|
||||
# constructs PaliGemmaWithExpertModel which instantiates the
|
||||
# Gemma/Siglip layers we want to swap).
|
||||
if getattr(config, "use_hf_kernels", False):
|
||||
_enable_hf_kernels()
|
||||
# Gemma/Siglip layers we want to swap). Always-on — the patch
|
||||
# is process-global / idempotent and degrades gracefully if
|
||||
# liger-kernel is missing.
|
||||
_enable_hf_kernels()
|
||||
|
||||
super().__init__(config, **kwargs)
|
||||
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
|
||||
|
||||
Reference in New Issue
Block a user