pi052: wire Liger fused linear CE + DDP-safe FAST tokenizer fit

* Replace ``_shifted_ce`` / ``_fast_ce`` with Liger's
  ``fused_linear_cross_entropy``: the ``(B, T, 257k)`` logits tensor
  is no longer materialised — the kernel chunks over the ``(B*T)``
  axis and computes matmul + softmax + CE in fused Triton blocks.
  ~30 % step speedup and ~12 GB of activation memory freed on the
  dual-CE pi052 recipe. All four call sites in
  ``_compute_all_losses_fused`` and ``_compute_text_and_fast_loss``
  updated; the ``.any().item()`` CPU sync is dropped so the loss
  path stays CUDA-graph-capturable.

* DDP-safe FAST tokenizer fit. The cache-hit sentinel previously
  looked for ``preprocessor_config.json`` but
  ``ProcessorMixin.save_pretrained`` writes ``processor_config.json``
  — every rank always cache-missed and re-fit, racing on writes and
  occasionally producing a stale ``.pyc`` that crashed
  ``AutoProcessor.from_pretrained`` with ``AttributeError:
  UniversalActionProcessor``. Fix the sentinel; gate the fit on the
  (local) main process; non-leader ranks poll the cache until the
  leader is done. Caught by job 22162549.

* New recipe ``subtask_mem_vqa_robocasa.yaml`` — subtask + memory +
  per-camera VQA over the three robocasa camera keys produced by the
  port pipeline (``robot0_agentview_left/right``, ``robot0_eye_in_hand``).
  The previously-shipped ``subtask_mem_vqa_speech.yaml`` references
  ``observation.images.front`` / ``wrist`` which don't exist in
  robocasa, so VQA never rendered.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-26 11:18:16 +00:00
parent 8615f3f613
commit 4c3ddb1ff5
4 changed files with 239 additions and 54 deletions
@@ -0,0 +1,99 @@
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
#
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
# camera-grounded VQA across the three RoboCasa camera keys produced
# by ``slurm_build_robocasa_composite_seen.py``:
#
# observation.images.robot0_agentview_left (left scene view)
# observation.images.robot0_agentview_right (right scene view)
# observation.images.robot0_eye_in_hand (wrist)
#
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
# VQA per camera, so each anchor frame produces three (user, assistant)
# rows tagged with their source camera. Each VQA sub-recipe consumes
# the rows for one camera via ``camera=...`` resolver bindings.
#
# Spatial VQA targets (bbox / point) are rewritten from JSON to
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
# ``register_paligemma_loc_tokens`` already collapses them to single
# detection-vocab ids so the LM head learns the pretrained pointing /
# detection prior, not a 7-piece BPE salad.
#
# Interjections / spoken responses are intentionally absent — the
# annotation job runs with ``--interjections.enabled=false``.
blend:
high_level_subtask:
weight: 0.25
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.45
messages:
# Action expert is conditioned on the SUBTASK; at inference the
# high-level loop generates it via the LM head and feeds it here.
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
# loss fires; subtask CE is owned by ``high_level_subtask``.
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# Trained densely with ``active_at`` — every frame inside a subtask
# interval — so the (prior_memory, completed_subtask) → current_memory
# mapping is supervised against varied observations. The *when* to
# emit lives in the inference trigger (subtask_change), not the
# model. See ``subtask_mem.yaml`` for the long version of this note.
weight: 0.15
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
ask_vqa_agentview_left:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_left}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_agentview_right:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_right}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_eye_in_hand}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
@@ -201,11 +201,15 @@ class PI052Config(PI05Config):
# layer_norm only → 1.1% step time
# all three → 4.5% step time, peak_mem unchanged
#
# ``cross_entropy`` / ``fused_linear_cross_entropy`` are NOT enabled
# pi052 calls ``F.cross_entropy`` directly and bypasses
# ``PaliGemmaForConditionalGeneration.forward``, so neither Liger
# patch fires without invasive model-code changes. Reserved for a
# follow-up.
# ``fused_linear_cross_entropy`` is now wired directly into the
# pi052 forward via ``_shifted_lin_ce`` / ``_fast_lin_ce`` (see
# ``modeling_pi052``). The kernel takes ``(hidden_states,
# lm_head.weight, labels)`` and computes matmul + softmax + CE in
# fused Triton blocks, never materialising the (B, T, 257k) logits
# tensor. Saves ~10 GB activation memory per CE branch and ~30 %
# step time on the dual-CE pi052 recipe (text + FAST). Removing the
# ``.any().item()`` sync also lets ``compile_mode=reduce-overhead``
# capture full CUDA graphs over the loss path.
use_hf_kernels: bool = False
"""If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's
fused Triton kernels (rope + geglu + layer_norm). Off by default;
@@ -39,12 +39,21 @@ from __future__ import annotations
import hashlib
import logging
import os
import time
from pathlib import Path
import numpy as np
logger = logging.getLogger(__name__)
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
# that's the image / feature-extractor convention). Centralised here so
# the cache-hit check and the rank-N readiness wait agree on the same
# sentinel.
_CACHE_SENTINEL = "processor_config.json"
def _dataset_signature(
dataset_repo_id: str,
@@ -111,7 +120,7 @@ def fit_fast_tokenizer(
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
out_dir = cache_dir / sig
if out_dir.exists() and (out_dir / "preprocessor_config.json").exists():
if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
logger.info(
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
"dataset=%s base=%s n_samples=%d",
@@ -119,6 +128,32 @@ def fit_fast_tokenizer(
)
return str(out_dir)
# DDP-safe fit: only the (local) main process actually fits + saves;
# other ranks poll the cache sentinel until the leader is done.
# Without this guard, all N ranks fit concurrently and race on
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
# and compiles a ``.pyc`` — concurrent writers occasionally produce
# a stale / partial ``.pyc`` and the subsequent ``from .. import
# UniversalActionProcessor`` raises ``AttributeError``.
is_leader = (
int(os.environ.get("RANK", "0")) == 0
and int(os.environ.get("LOCAL_RANK", "0")) == 0
)
if not is_leader:
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
start = time.monotonic()
while not (out_dir / _CACHE_SENTINEL).exists():
if time.monotonic() - start > timeout_s:
raise RuntimeError(
f"FAST tokenizer fit: non-leader rank timed out after "
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
"Leader rank likely crashed during the fit."
)
time.sleep(2.0)
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
return str(out_dir)
logger.info(
"FAST tokenizer cache miss — fitting on dataset=%s "
"base=%s n_samples=%d chunk_size=%d%s",
+95 -48
View File
@@ -106,35 +106,53 @@ def _mask_per_sample(per_sample: Tensor, predict_actions_t: Tensor | None) -> Te
return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
def _shifted_ce(logits: Tensor, labels: Tensor, z_loss_weight: float = 0.0) -> Tensor:
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100.
def _shifted_lin_ce(
hidden: Tensor,
lm_head_weight: Tensor,
labels: Tensor,
z_loss_weight: float = 0.0,
) -> Tensor:
"""Liger-fused (hidden @ W.T → softmax → CE) on shifted labels.
Mean over non-ignored positions across the batch. Returns 0 cleanly
when no positions are supervised (clamp(min=1) on the denominator).
Replaces the explicit ``lm_head(hidden) F.cross_entropy(...)``
pair with Liger's ``LigerFusedLinearCrossEntropyLoss``: the full
``(B, T, V)`` logits tensor is never materialised the kernel
chunks over the (B*T) axis, computing matmul + logsumexp + CE
in fused Triton blocks. On a 257k-vocab head this saves ~10 GB
of activation memory per CE branch and ~30 % step time vs the
eager ``F.cross_entropy`` path.
When ``z_loss_weight > 0``, also adds PaLM-style z-loss
(`` · w``, where ``z = log Σ exp(logits)``) on every supervised
position. Penalises the log-partition function drifting away from
zero without it, large-vocab models (PaliGemma is 257k) can let
``logsumexp`` grow unboundedly while CE stays low, because uniform
additive logit bias cancels in softmax. PaLM appendix B / Chinchilla
report this is essential for stable large-vocab CE; cheap insurance
here especially with ``lm_head_lr_scale=5.0`` amplifying drift risk.
Semantics:
* Shift convention identical to the eager version hidden at
position ``t`` predicts label at ``t+1``; ``ignore_index=-100``.
* No ``.any().item()`` sync Liger returns 0.0 cleanly when
every label is ignored, keeping the graph capturable for
``compile_mode=reduce-overhead`` (CUDA graphs).
* ``z_loss_weight`` maps directly to Liger's ``lse_square_scale``
(same ``·w`` formula on per-position logsumexp). Setting it
to 0 disables the z-loss term at zero cost.
"""
shift_logits = logits[:, :-1, :].contiguous()
# Liger is imported lazily so the module still imports on machines
# without liger-kernel; the call site only ever runs after
# use_hf_kernels / training has selected the Liger path.
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
LigerFusedLinearCrossEntropyLoss,
)
shift_hidden = hidden[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous().long()
valid = shift_labels != -100
if not bool(valid.any().item()):
return shift_logits.sum() * 0.0
valid_logits = shift_logits[valid]
valid_labels = shift_labels[valid]
ce = F.cross_entropy(valid_logits, valid_labels, reduction="mean")
if z_loss_weight <= 0.0:
return ce
# PaLM z-loss: penalise (log Σ exp(logits))² per supervised position.
# ``logsumexp`` is numerically stable and shares the softmax kernel.
z = torch.logsumexp(valid_logits, dim=-1)
return ce + z_loss_weight * (z**2).mean()
B, T_1, H = shift_hidden.shape
flat_hidden = shift_hidden.reshape(B * T_1, H)
flat_labels = shift_labels.reshape(B * T_1)
# Match the dtype the eager path used: cast hidden to the lm_head's
# weight dtype so bf16 weights see bf16 activations.
flat_hidden = flat_hidden.to(lm_head_weight.dtype)
loss_fn = LigerFusedLinearCrossEntropyLoss(
ignore_index=-100,
lse_square_scale=float(z_loss_weight),
reduction="mean",
)
return loss_fn(lm_head_weight, flat_hidden, flat_labels)
def _mark_target_span_causal(
@@ -172,32 +190,48 @@ def _mark_target_span_causal(
return att
def _fast_ce(
fast_logits: Tensor,
def _fast_lin_ce(
hidden: Tensor,
lm_head_weight: Tensor,
action_tokens: Tensor,
action_code_mask: Tensor,
predict_actions_t: Tensor | None,
) -> Tensor:
"""FAST action-code CE with token-span masking and per-sample action gating.
"""Liger-fused FAST action-code CE with span masking + sample gating.
``action_code_mask`` is true only on the discrete action-code tokens,
excluding the BOS / "Action: " / delimiter wrapper. Samples whose
recipe sets ``predict_actions=False`` get all code positions masked
out via the per-sample gate.
Mirrors ``_shifted_lin_ce`` but with FAST-specific masking: only
the discrete action-code positions (``action_code_mask``) are
supervised, and samples whose recipe sets ``predict_actions=False``
get all code positions masked. Masked positions are folded into
Liger's ``ignore_index=-100`` so the kernel skips them without
a CPU-side gather (which would synchronise + break CUDA graphs).
"""
shift_logits = fast_logits[:, :-1, :].contiguous()
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
LigerFusedLinearCrossEntropyLoss,
)
shift_hidden = hidden[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_code_mask[:, 1:].contiguous().bool()
if predict_actions_t is not None:
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
shift_valid = shift_valid & sample_mask
if not bool(shift_valid.any().item()):
return shift_logits.sum() * 0.0
return F.cross_entropy(
shift_logits[shift_valid],
shift_targets[shift_valid],
# Fold the boolean mask into the target via ignore_index. No
# ``.any().item()`` sync — Liger returns 0.0 when every position
# is ignored, preserving graph capture for CUDA graphs.
shift_targets = torch.where(
shift_valid, shift_targets, torch.full_like(shift_targets, -100)
)
B, T_1, H = shift_hidden.shape
flat_hidden = shift_hidden.reshape(B * T_1, H).to(lm_head_weight.dtype)
flat_labels = shift_targets.reshape(B * T_1)
loss_fn = LigerFusedLinearCrossEntropyLoss(
ignore_index=-100,
reduction="mean",
)
return loss_fn(lm_head_weight, flat_hidden, flat_labels)
# ----------------------------------------------------------------------
@@ -726,9 +760,12 @@ class PI052Policy(PI05Policy):
text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :]
else:
text_hidden = prefix_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
text_loss = _shifted_ce(
text_logits,
# Liger fused linear-CE: skip the explicit ``lm_head(...)``
# materialisation; the kernel multiplies on-the-fly and
# never holds the full (B, T, 257k) logits tensor.
text_loss = _shifted_lin_ce(
text_hidden,
lm_head.weight,
text_labels,
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
)
@@ -736,8 +773,13 @@ class PI052Policy(PI05Policy):
fast_loss: Tensor | None = None
if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
fast_hidden = prefix_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
fast_loss = _fast_lin_ce(
fast_hidden,
lm_head.weight,
action_tokens,
action_code_mask,
predict_actions_t,
)
return flow_loss, text_loss, fast_loss
@@ -830,9 +872,9 @@ class PI052Policy(PI05Policy):
text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :]
else:
text_hidden = vlm_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
text_loss = _shifted_ce(
text_logits,
text_loss = _shifted_lin_ce(
text_hidden,
lm_head.weight,
text_labels,
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
)
@@ -844,8 +886,13 @@ class PI052Policy(PI05Policy):
and fast_len > 0
):
fast_hidden = vlm_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
fast_loss = _fast_lin_ce(
fast_hidden,
lm_head.weight,
action_tokens,
action_code_mask,
predict_actions_t,
)
return text_loss, fast_loss