mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 23:19:48 +00:00
pi052: drop `use_hf_kernels` flag — always patch Liger kernels
The flag gated a process-global, idempotent Liger patch that swaps in fused Triton rope / geglu / layer_norm kernels (~4.5 % lower step time on H100, bench job 22161421). Since liger-kernel is now a hard dependency of the loss path (``_shifted_lin_ce`` / ``_fast_lin_ce`` in ``modeling_pi052``), gating the same dependency behind an opt-in flag was redundant — every pi052 run pulls the wheel in either way. * ``PI052Policy.__init__`` now calls ``_enable_hf_kernels()`` unconditionally; the function still degrades gracefully if the wheel happens to be missing (logs a warning, returns). * Drop ``PI052Config.use_hf_kernels`` together with its docstring; the bench numbers and the ``fused_linear_cross_entropy`` pointer to ``_shifted_lin_ce`` / ``_fast_lin_ce`` are kept as a comment in ``PI052Config`` where the flag used to be. * Update the warning message and the ``_shifted_lin_ce`` lazy-import comment to drop stale ``use_hf_kernels`` / ``reduce-overhead`` references. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -190,30 +190,13 @@ class PI052Config(PI05Config):
|
|||||||
# commonly cited weight; set 0 to disable entirely.
|
# commonly cited weight; set 0 to disable entirely.
|
||||||
text_ce_z_loss_weight: float = 1e-4
|
text_ce_z_loss_weight: float = 1e-4
|
||||||
|
|
||||||
# Fused kernels (Liger via HF kernels lib) ---------------------------
|
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
|
||||||
# Patches PaliGemma / Gemma / Siglip ops with Liger Triton kernels
|
# unconditionally at model build time — see ``_enable_hf_kernels``
|
||||||
# before the model is built. Measured on H100 80GB at BS=16 / L=512
|
# in ``modeling_pi052``. The patch is process-global, idempotent
|
||||||
# with KI+GC on (bench job 22161421, see
|
# and degrades gracefully if ``liger-kernel`` is missing. Measured
|
||||||
# ``examples/benchmark/bench_pi052_kernels.slurm``):
|
# at -4.5% step time on H100 (bench job 22161421); peak memory
|
||||||
#
|
# unchanged. ``fused_linear_cross_entropy`` ships separately via
|
||||||
# rope only → −2.5% step time
|
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
|
||||||
# geglu only → −2.2% step time
|
|
||||||
# layer_norm only → −1.1% step time
|
|
||||||
# all three → −4.5% step time, peak_mem unchanged
|
|
||||||
#
|
|
||||||
# ``fused_linear_cross_entropy`` is now wired directly into the
|
|
||||||
# pi052 forward via ``_shifted_lin_ce`` / ``_fast_lin_ce`` (see
|
|
||||||
# ``modeling_pi052``). The kernel takes ``(hidden_states,
|
|
||||||
# lm_head.weight, labels)`` and computes matmul + softmax + CE in
|
|
||||||
# fused Triton blocks, never materialising the (B, T, 257k) logits
|
|
||||||
# tensor. Saves ~10 GB activation memory per CE branch and ~30 %
|
|
||||||
# step time on the dual-CE pi052 recipe (text + FAST). Removing the
|
|
||||||
# ``.any().item()`` sync also lets ``compile_mode=reduce-overhead``
|
|
||||||
# capture full CUDA graphs over the loss path.
|
|
||||||
use_hf_kernels: bool = False
|
|
||||||
"""If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's
|
|
||||||
fused Triton kernels (rope + geglu + layer_norm). Off by default;
|
|
||||||
requires ``pip install liger-kernel``."""
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|||||||
@@ -77,8 +77,9 @@ def _enable_hf_kernels() -> None:
|
|||||||
from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415
|
from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"PI052: use_hf_kernels=True but liger-kernel is not installed; "
|
"PI052: liger-kernel is not installed; skipping fused Triton "
|
||||||
"skipping. Install with `pip install liger-kernel`."
|
"kernels (rope/geglu/layer_norm). Install with "
|
||||||
|
"``pip install liger-kernel`` for a ~4.5%% step speedup."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
apply_liger_kernel_to_paligemma(
|
apply_liger_kernel_to_paligemma(
|
||||||
@@ -126,15 +127,14 @@ def _shifted_lin_ce(
|
|||||||
* Shift convention identical to the eager version — hidden at
|
* Shift convention identical to the eager version — hidden at
|
||||||
position ``t`` predicts label at ``t+1``; ``ignore_index=-100``.
|
position ``t`` predicts label at ``t+1``; ``ignore_index=-100``.
|
||||||
* No ``.any().item()`` sync — Liger returns 0.0 cleanly when
|
* No ``.any().item()`` sync — Liger returns 0.0 cleanly when
|
||||||
every label is ignored, keeping the graph capturable for
|
every label is ignored.
|
||||||
``compile_mode=reduce-overhead`` (CUDA graphs).
|
|
||||||
* ``z_loss_weight`` maps directly to Liger's ``lse_square_scale``
|
* ``z_loss_weight`` maps directly to Liger's ``lse_square_scale``
|
||||||
(same ``z²·w`` formula on per-position logsumexp). Setting it
|
(same ``z²·w`` formula on per-position logsumexp). Setting it
|
||||||
to 0 disables the z-loss term at zero cost.
|
to 0 disables the z-loss term at zero cost.
|
||||||
"""
|
"""
|
||||||
# Liger is imported lazily so the module still imports on machines
|
# Liger is imported lazily so the module still imports on machines
|
||||||
# without liger-kernel; the call site only ever runs after
|
# without liger-kernel — the call site only fires from the training
|
||||||
# use_hf_kernels / training has selected the Liger path.
|
# forward, which always pulls in the kernel.
|
||||||
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
|
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
|
||||||
LigerFusedLinearCrossEntropyLoss,
|
LigerFusedLinearCrossEntropyLoss,
|
||||||
)
|
)
|
||||||
@@ -434,9 +434,10 @@ class PI052Policy(PI05Policy):
|
|||||||
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
|
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
|
||||||
# Patch ops BEFORE the backbone is built (super().__init__ below
|
# Patch ops BEFORE the backbone is built (super().__init__ below
|
||||||
# constructs PaliGemmaWithExpertModel which instantiates the
|
# constructs PaliGemmaWithExpertModel which instantiates the
|
||||||
# Gemma/Siglip layers we want to swap).
|
# Gemma/Siglip layers we want to swap). Always-on — the patch
|
||||||
if getattr(config, "use_hf_kernels", False):
|
# is process-global / idempotent and degrades gracefully if
|
||||||
_enable_hf_kernels()
|
# liger-kernel is missing.
|
||||||
|
_enable_hf_kernels()
|
||||||
|
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
|
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
|
||||||
|
|||||||
Reference in New Issue
Block a user