mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
4913356564
Replaces the per-layer ``modeling_gemma.eager_attention_forward`` call with ``torch.nn.functional.scaled_dot_product_attention`` in ``compute_layer_complete`` (pi05) and ``_compute_layer_ki`` (pi052). PyTorch SDPA picks the memory-efficient kernel for the block-bidirectional 4D additive mask the dual-expert model uses (FA2 / FA3 reject it because they only accept causal / sliding-window / varlen patterns). The shared ``sdpa_attention_forward`` helper mirrors the eager signature so the call sites are unchanged.

Selective AC: removes the redundant outer ``_apply_checkpoint(forward_func, ...)`` wrap in ``PI05Pytorch.forward``. Per-layer checkpointing inside ``PaliGemmaWithExpertModel.forward`` already handles activation recompute; the outer wrap was double-recomputing the whole backbone. +14% steps/sec on its own (job 22161405 vs 22161398, 1xH100).

groot: drop ``@strict`` on ``GR00TN15Config`` — newer ``huggingface_hub`` rejects ``@strict`` on non-dataclass ``PretrainedConfig`` subclasses, which was blocking imports of any sibling policy through ``lerobot.policies.factory``.

New ``examples/benchmark/bench_pi052_step.py`` (+ slurm sweeps v1..v8) times PI052Policy.forward+backward (optionally with AdamW) on synthetic inputs. Headline numbers on 1xH100 with KI=True, GC=True, L=512, 4.14 B trainable params, AdamW state in bf16:

    pre-SDPA eager BS=8              610ms   19.5 GiB -> 13.1 samples/s
    sdpa BS=8  + compile=default     413ms   19.5 GiB -> 19.3 samples/s
    sdpa BS=16 + compile=default     715ms   37.3 GiB -> 22.4 samples/s
    sdpa BS=32 + compile=default    1325ms   44.8 GiB -> 24.2 samples/s
    sdpa BS=40 + compile=default    1665ms   48.6 GiB -> 24.0 samples/s

Parity tests in ``tests/policies/pi052/test_pi052_sdpa_attention.py`` cover fp32 / bf16 / GQA / MHA forward + backward — output and grads match the eager path within bf16 tolerance.
Also ships ``examples/benchmark/fsdp_pi052.yaml`` (FSDP2 accelerate config wrapping GemmaDecoderLayer + SiglipEncoderLayer) for the follow-up multi-GPU memory sharding work. Co-authored-by: Cursor <cursoragent@cursor.com>
339 lines
12 KiB
Python
339 lines
12 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Benchmark ``PI052Policy.forward + backward`` on a single GPU.
|
|
|
|
Compares the new SDPA attention path against the eager baseline by
|
|
monkeypatching ``sdpa_attention_forward`` before the first model
|
|
forward — so both runs share identical Q/K/V plumbing and only the
|
|
attention kernel differs. Reports steps/sec and peak GPU memory.
|
|
|
|
SLURM-only:
|
|
|
|
sbatch examples/benchmark/bench_pi052_step.slurm
|
|
|
|
Or one-off:
|
|
|
|
srun --partition=hopper-prod --qos=high --gpus=1 --time=15 \\
|
|
python examples/benchmark/bench_pi052_step.py --attn sdpa --batch-size 8
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import gc
|
|
import math
|
|
import os
|
|
import time
|
|
|
|
import torch
|
|
|
|
|
|
def _maybe_patch_eager() -> None:
    """Route pi05 attention through the Gemma eager kernel.

    Rebinds the module-level name ``sdpa_attention_forward`` in
    ``lerobot.policies.pi05.modeling_pi05`` to transformers'
    ``modeling_gemma.eager_attention_forward``.

    Invoke this BEFORE PI052Policy is instantiated. The layer compute
    functions are expected to look the symbol up on the module at call
    time, which is what lets a single rebinding cover both the pi05 and
    the pi052 KI paths.
    """
    from transformers.models.gemma.modeling_gemma import eager_attention_forward

    from lerobot.policies.pi05 import modeling_pi05 as pi05_modeling

    pi05_modeling.sdpa_attention_forward = eager_attention_forward
|
|
|
|
|
|
_LIGER_SUBKERNELS = ("rope", "rms_norm", "geglu", "layer_norm")


def _maybe_patch_liger(spec: str) -> dict[str, bool]:
    """Globally patch PaliGemma/Gemma/Siglip modules with Liger Triton kernels.

    Must be called BEFORE PI052Policy is instantiated — Liger replaces
    classes inside ``transformers.models.{gemma,gemma2,siglip,paligemma}``,
    so any model built after the call picks up the fused forwards.

    ``cross_entropy`` and ``fused_linear_cross_entropy`` are intentionally
    skipped — pi052's losses use ``F.cross_entropy`` directly (not
    ``nn.CrossEntropyLoss``) and never traverse
    ``PaliGemmaForConditionalGeneration.forward``, so neither patch would
    fire without invasive model-code changes.

    Args:
        spec: Comma-separated subset of {rope, rms_norm, geglu, layer_norm},
            or ``all`` / ``none``. Whitespace around tokens and empty
            tokens are ignored, so ``" none "`` and ``"rope, geglu"`` are
            accepted; a spec with no tokens at all behaves like ``none``.

    Returns:
        Mapping from every name in ``_LIGER_SUBKERNELS`` to whether it was
        enabled.

    Raises:
        SystemExit: If ``spec`` names an unknown sub-kernel.
    """
    enabled = dict.fromkeys(_LIGER_SUBKERNELS, False)

    # Tokenise before looking for the "nothing requested" case so padded
    # spellings are handled the same way as their bare forms.
    tokens = [t.strip() for t in spec.split(",") if t.strip()]
    if not tokens or tokens == ["none"]:
        # Return before importing liger_kernel: the benchmark must keep
        # working on machines where Liger is not installed.
        return enabled

    if tokens == ["all"]:
        enabled = dict.fromkeys(_LIGER_SUBKERNELS, True)
    else:
        for t in tokens:
            if t not in enabled:
                raise SystemExit(f"Unknown liger subkernel: {t!r}. Choose from {_LIGER_SUBKERNELS} or 'all'.")
            enabled[t] = True

    # Deferred import: only needed when at least one kernel was requested.
    from liger_kernel.transformers import apply_liger_kernel_to_paligemma

    apply_liger_kernel_to_paligemma(
        rope=enabled["rope"],
        rms_norm=enabled["rms_norm"],
        geglu=enabled["geglu"],
        layer_norm=enabled["layer_norm"],
        cross_entropy=False,
        fused_linear_cross_entropy=False,
    )
    return enabled
|
|
|
|
|
|
def _maybe_patch_flex() -> None:
    """Rebind ``sdpa_attention_forward`` to a FlexAttention-based forward.

    Experimental. The replacement forward turns the additive attention
    mask into a per-call ``score_mod`` and runs it through a
    ``torch.compile``-d ``flex_attention``.

    Known issue on torch 2.7.1: dynamo errors out with
    ``FlexAttentionHigherOrderVariable() has no type`` when the
    ``score_mod`` closure captures a per-call bias tensor. A proper port
    needs ``create_block_mask(mask_mod, ...)`` plumbed at the
    PI05Pytorch.forward level so that a BlockMask object, not a per-call
    closure, is handed down to the layer compute. That is left as future
    work; this stub is kept for benchmark experimentation.
    """
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    from lerobot.policies.pi05 import modeling_pi05

    # Compile once here; the wrapper is shared by every attention call.
    flex_kernel = torch.compile(flex_attention, dynamic=True)

    def flex_forward(module, query, key, value, attention_mask, scaling, dropout=0.0):
        # ``dropout`` is accepted for signature parity and is not used.
        groups = module.num_key_value_groups
        if groups > 1:
            # Repeat each K/V head along dim 1 once per group.
            key, value = (t.repeat_interleave(groups, dim=1) for t in (key, value))

        additive_mask = attention_mask  # (B, 1, Lq, Lk) additive

        def score_mod(score, b, h, q_idx, kv_idx):
            # Dim 1 of the mask is indexed at 0 for every head ``h``.
            return score + additive_mask[b, 0, q_idx, kv_idx]

        out = flex_kernel(query, key, value, score_mod=score_mod, scale=scaling)
        # Second tuple slot is always None.
        return out.transpose(1, 2).contiguous(), None

    modeling_pi05.sdpa_attention_forward = flex_forward
|
|
|
|
|
|
def _build_policy(args, device: torch.device):
    """Create a randomly initialised PI052Policy in training mode.

    Args:
        args: Parsed CLI namespace; reads ``action_dim``, ``state_dim``,
            ``dtype``, ``knowledge_insulation``, ``train_full``,
            ``lang_tokens``, ``compile_model``, ``compile_mode`` and
            ``gradient_checkpointing``.
        device: Device the policy is moved to after construction.

    Returns:
        ``(policy, config)``.
    """
    from lerobot.configs.types import FeatureType, PolicyFeature
    from lerobot.policies.pi052.configuration_pi052 import PI052Config
    from lerobot.policies.pi052.modeling_pi052 import PI052Policy

    full_training = args.train_full

    # Production has ``unfreeze_lm_head=True`` + ``text_loss_weight>0``,
    # which flips ``train_expert_only=False`` in __post_init__ and makes
    # the whole PaliGemma + Gemma-expert stack trainable. Mirroring that
    # keeps the optimizer-state count realistic. The loss path still goes
    # through ``PI05Policy.forward`` because ``text_labels`` / FAST tokens
    # are absent from the synthetic batch (see ``PI052Policy.forward``
    # early-return).
    config_kwargs = {
        "max_action_dim": args.action_dim,
        "max_state_dim": args.state_dim,
        "dtype": args.dtype,
        "knowledge_insulation": args.knowledge_insulation,
        "text_loss_weight": 1e-3 if full_training else 0.0,
        "flow_loss_weight": 1.0,
        "enable_fast_action_loss": False,
        "unfreeze_lm_head": full_training,
        "tokenizer_max_length": args.lang_tokens,
        "device": "cuda",
        "compile_model": args.compile_model,
        "compile_mode": args.compile_mode,
    }
    config = PI052Config(**config_kwargs)

    # One state vector and one 224x224 RGB camera in, one action vector out.
    state_feature = PolicyFeature(type=FeatureType.STATE, shape=(args.state_dim,))
    image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
    action_feature = PolicyFeature(type=FeatureType.ACTION, shape=(args.action_dim,))
    config.input_features = {
        "observation.state": state_feature,
        "observation.images.base_0_rgb": image_feature,
    }
    config.output_features = {"action": action_feature}

    policy = PI052Policy(config)
    policy.to(device)
    if args.gradient_checkpointing:
        policy.model.gradient_checkpointing_enable()
    policy.train()
    return policy, config
|
|
|
|
|
|
def _build_batch(args, config, device: torch.device) -> dict:
    """Assemble one synthetic batch matching the training-loop input contract.

    Args:
        args: Parsed CLI namespace; reads ``batch_size``, ``lang_tokens``,
            ``state_dim`` and ``action_dim``.
        config: Policy config; only ``chunk_size`` is read.
        device: Device every tensor is allocated on.

    Returns:
        Dict holding language tokens and mask, one camera image with its
        padding mask, the state, the action chunk with its pad mask, and a
        list of task strings.
    """
    from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

    batch_size = args.batch_size
    seq_len = args.lang_tokens
    horizon = config.chunk_size

    # Random draws happen in dict order (tokens, image, state, action) so
    # the RNG stream is consumed in a fixed sequence.
    lang_tokens = torch.randint(0, 250000, (batch_size, seq_len), device=device)
    image = torch.rand(batch_size, 3, 224, 224, device=device)
    state = torch.randn(batch_size, args.state_dim, device=device)
    actions = torch.randn(batch_size, horizon, args.action_dim, device=device)

    # Deterministic masks: every token and image is valid, no action is padded.
    lang_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
    image_valid = torch.ones(batch_size, dtype=torch.bool, device=device)
    action_pad = torch.zeros(batch_size, horizon, dtype=torch.bool, device=device)

    return {
        OBS_LANGUAGE_TOKENS: lang_tokens,
        OBS_LANGUAGE_ATTENTION_MASK: lang_mask,
        "observation.images.base_0_rgb": image,
        "observation.images.base_0_rgb_padding_mask": image_valid,
        "observation.state": state,
        ACTION: actions,
        "action_is_pad": action_pad,
        "task": ["bench task"] * batch_size,
    }
|
|
|
|
|
|
def _step(policy, batch, optimizer=None) -> torch.Tensor:
    """Run one forward + backward pass, plus an optimizer update if given.

    Args:
        policy: Object whose ``forward(batch)`` returns ``(loss, extra)``
            and which exposes ``parameters()``.
        batch: Input passed unchanged to ``policy.forward``.
        optimizer: Optional optimizer. When provided it is stepped and its
            gradients are cleared; otherwise parameter gradients are
            dropped by hand.

    Returns:
        The loss, detached from the autograd graph.
    """
    loss, _ = policy.forward(batch)
    loss.backward()

    if optimizer is None:
        # No optimizer to clear gradients, so release them here; otherwise
        # they would accumulate across benchmark iterations.
        for param in policy.parameters():
            if param.grad is not None:
                param.grad = None
        return loss.detach()

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.detach()
|
|
|
|
|
|
def _parse_args(argv=None) -> argparse.Namespace:
    """Build the benchmark CLI and parse ``argv`` (``sys.argv[1:]`` when None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--attn", choices=["sdpa", "eager", "flex"], default="sdpa")
    parser.add_argument(
        "--kernels",
        default="none",
        help=(
            "Liger sub-kernels to enable, comma-separated. Choose from "
            f"{_LIGER_SUBKERNELS} or use 'all' / 'none' (default). Applied "
            "via apply_liger_kernel_to_paligemma() BEFORE model build."
        ),
    )
    parser.add_argument(
        "--compile",
        dest="compile_model",
        action="store_true",
        help="Set policy.config.compile_model=True (torch.compile the forward).",
    )
    parser.add_argument(
        "--compile-mode",
        default="default",
        help="torch.compile mode (default | reduce-overhead | max-autotune).",
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--warmup", type=int, default=8)
    parser.add_argument("--steps", type=int, default=40)
    parser.add_argument("--lang-tokens", type=int, default=512)
    parser.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16")
    parser.add_argument("--action-dim", type=int, default=14)
    parser.add_argument("--state-dim", type=int, default=14)
    # BooleanOptionalAction rather than ``store_true``: with default=True a
    # store_true flag could never be switched off. This keeps
    # ``--knowledge-insulation`` working and adds ``--no-knowledge-insulation``.
    parser.add_argument(
        "--knowledge-insulation",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enable knowledge insulation (on by default; use --no-knowledge-insulation to disable).",
    )
    parser.add_argument(
        "--gradient-checkpointing",
        dest="gradient_checkpointing",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument(
        "--optimizer",
        choices=["none", "adamw", "adamw_fused"],
        default="adamw_fused",
        help=(
            "Whether to include an AdamW step in the timed iteration. "
            "'none' mirrors the fwd+bwd-only original bench; 'adamw' / "
            "'adamw_fused' add the realistic ~2x param-bytes optimizer "
            "state and ``optimizer.step()`` cost."
        ),
    )
    parser.add_argument(
        "--train-full",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Mirror production: unfreeze the PaliGemma backbone (full "
            "~3B trainable params) instead of training only the 300M "
            "action expert."
        ),
    )
    args = parser.parse_args(argv)

    # ``steps`` is the divisor of the per-step timing; reject it here
    # instead of failing with ZeroDivisionError after the full warm-up.
    if args.steps < 1:
        parser.error("--steps must be >= 1")
    return args


def main(argv=None) -> int:
    """Time ``PI052Policy`` training steps on one GPU and print a summary.

    Applies the requested attention / Liger patches, builds a random-init
    policy and a synthetic batch, runs ``--warmup`` untimed steps followed
    by ``--steps`` steps timed with CUDA events, then prints step time,
    steps/sec, peak allocated memory, optimizer-state size and the
    trainable-parameter count.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: On invalid arguments, when CUDA is unavailable, or on
            an unknown ``--kernels`` entry.
    """
    args = _parse_args(argv)

    if not torch.cuda.is_available():
        raise SystemExit("Benchmark requires CUDA; submit via slurm (srun/sbatch).")

    # Both patch helpers document that they must run before the policy is
    # instantiated, so they are applied first.
    if args.attn == "eager":
        _maybe_patch_eager()
    elif args.attn == "flex":
        _maybe_patch_flex()

    liger_flags = _maybe_patch_liger(args.kernels)

    device = torch.device("cuda")
    # Reset before the model is built so the reported peak includes
    # weights, activations and optimizer state.
    torch.cuda.reset_peak_memory_stats()

    policy, config = _build_policy(args, device)
    batch = _build_batch(args, config, device)

    trainable = [p for p in policy.parameters() if p.requires_grad]
    trainable_params = sum(p.numel() for p in trainable)
    optimizer = None
    if args.optimizer != "none":
        optimizer = torch.optim.AdamW(
            trainable, lr=5e-5, fused=(args.optimizer == "adamw_fused")
        )

    # Untimed warm-up iterations.
    for _ in range(args.warmup):
        _step(policy, batch, optimizer)
    torch.cuda.synchronize()

    # CUDA events measure elapsed GPU time between the two records.
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    starter.record()
    for _ in range(args.steps):
        _step(policy, batch, optimizer)
    ender.record()
    torch.cuda.synchronize()
    total_ms = starter.elapsed_time(ender)
    step_ms = total_ms / args.steps
    peak_gb = torch.cuda.max_memory_allocated() / (1024**3)

    # Sum the bytes of every tensor held in the optimizer state.
    optim_gb = 0.0
    if optimizer is not None:
        for st in optimizer.state.values():
            for v in st.values():
                if torch.is_tensor(v):
                    optim_gb += v.numel() * v.element_size() / (1024**3)

    liger_on = ",".join(k for k, v in liger_flags.items() if v) or "none"
    name = (
        f"{args.attn:>5} | BS={args.batch_size} | L={args.lang_tokens} | "
        f"KI={args.knowledge_insulation} | GC={args.gradient_checkpointing} | "
        f"compile={args.compile_model} | liger={liger_on} | opt={args.optimizer} | dtype={args.dtype}"
    )
    print(
        f"{name}\n step_ms={step_ms:.1f} steps/sec={1000.0 / step_ms:.3f} "
        f"peak_mem={peak_gb:.2f} GiB optim_state={optim_gb:.2f} GiB "
        f"trainable_params={trainable_params / 1e9:.2f}B"
    )

    # The optimizer holds references to the parameters, so it must be
    # dropped with the policy for empty_cache() to release their memory.
    del policy, batch, optimizer, trainable
    gc.collect()
    torch.cuda.empty_cache()
    return 0
|
|
|
|
|
|
# Script entry point: run the benchmark and raise SystemExit with the
# value returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
|