pi052: wire Liger fused linear CE + DDP-safe FAST tokenizer fit

* Replace ``_shifted_ce`` / ``_fast_ce`` with Liger's
  ``fused_linear_cross_entropy``: the ``(B, T, 257k)`` logits tensor
  is no longer materialised — the kernel chunks over the ``(B*T)``
  axis and computes matmul + softmax + CE in fused Triton blocks.
  ~30 % step speedup and ~12 GB of activation memory freed on the
  dual-CE pi052 recipe. All four call sites in
  ``_compute_all_losses_fused`` and ``_compute_text_and_fast_loss``
  updated; the ``.any().item()`` CPU sync is dropped so the loss
  path stays CUDA-graph-capturable.

* DDP-safe FAST tokenizer fit. The cache-hit sentinel previously
  looked for ``preprocessor_config.json`` but
  ``ProcessorMixin.save_pretrained`` writes ``processor_config.json``
  — every rank always cache-missed and re-fit, racing on writes and
  occasionally producing a stale ``.pyc`` that crashed
  ``AutoProcessor.from_pretrained`` with ``AttributeError:
  UniversalActionProcessor``. Fix the sentinel; gate the fit on the
  (local) main process; non-leader ranks poll the cache until the
  leader is done. Caught by job 22162549.

* New recipe ``subtask_mem_vqa_robocasa.yaml`` — subtask + memory +
  per-camera VQA over the three robocasa camera keys produced by the
  port pipeline (``robot0_agentview_left/right``, ``robot0_eye_in_hand``).
  The previously-shipped ``subtask_mem_vqa_speech.yaml`` references
  ``observation.images.front`` / ``wrist`` which don't exist in
  robocasa, so VQA never rendered.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-26 11:18:16 +00:00
parent 8615f3f613
commit 4c3ddb1ff5
4 changed files with 239 additions and 54 deletions
@@ -0,0 +1,99 @@
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
#
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
# camera-grounded VQA across the three RoboCasa camera keys produced
# by ``slurm_build_robocasa_composite_seen.py``:
#
# observation.images.robot0_agentview_left (left scene view)
# observation.images.robot0_agentview_right (right scene view)
# observation.images.robot0_eye_in_hand (wrist)
#
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
# VQA per camera, so each anchor frame produces three (user, assistant)
# rows tagged with their source camera. Each VQA sub-recipe consumes
# the rows for one camera via ``camera=...`` resolver bindings.
#
# Spatial VQA targets (bbox / point) are rewritten from JSON to
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
# ``register_paligemma_loc_tokens`` already collapses them to single
# detection-vocab ids so the LM head learns the pretrained pointing /
# detection prior, not a 7-piece BPE salad.
#
# Interjections / spoken responses are intentionally absent — the
# annotation job runs with ``--interjections.enabled=false``.
blend:
high_level_subtask:
weight: 0.25
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.45
messages:
# Action expert is conditioned on the SUBTASK; at inference the
# high-level loop generates it via the LM head and feeds it here.
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
# loss fires; subtask CE is owned by ``high_level_subtask``.
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# Trained densely with ``active_at`` — every frame inside a subtask
# interval — so the (prior_memory, completed_subtask) → current_memory
# mapping is supervised against varied observations. The *when* to
# emit lives in the inference trigger (subtask_change), not the
# model. See ``subtask_mem.yaml`` for the long version of this note.
weight: 0.15
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
ask_vqa_agentview_left:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_left}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_agentview_right:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_right}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_eye_in_hand}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
@@ -201,11 +201,15 @@ class PI052Config(PI05Config):
# layer_norm only → 1.1% step time # layer_norm only → 1.1% step time
# all three → 4.5% step time, peak_mem unchanged # all three → 4.5% step time, peak_mem unchanged
# #
# ``cross_entropy`` / ``fused_linear_cross_entropy`` are NOT enabled # ``fused_linear_cross_entropy`` is now wired directly into the
# pi052 calls ``F.cross_entropy`` directly and bypasses # pi052 forward via ``_shifted_lin_ce`` / ``_fast_lin_ce`` (see
# ``PaliGemmaForConditionalGeneration.forward``, so neither Liger # ``modeling_pi052``). The kernel takes ``(hidden_states,
# patch fires without invasive model-code changes. Reserved for a # lm_head.weight, labels)`` and computes matmul + softmax + CE in
# follow-up. # fused Triton blocks, never materialising the (B, T, 257k) logits
# tensor. Saves ~10 GB activation memory per CE branch and ~30 %
# step time on the dual-CE pi052 recipe (text + FAST). Removing the
# ``.any().item()`` sync also lets ``compile_mode=reduce-overhead``
# capture full CUDA graphs over the loss path.
use_hf_kernels: bool = False use_hf_kernels: bool = False
"""If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's """If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's
fused Triton kernels (rope + geglu + layer_norm). Off by default; fused Triton kernels (rope + geglu + layer_norm). Off by default;
@@ -39,12 +39,21 @@ from __future__ import annotations
import hashlib import hashlib
import logging import logging
import os
import time
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
# that's the image / feature-extractor convention). Centralised here so
# the cache-hit check and the rank-N readiness wait agree on the same
# sentinel.
_CACHE_SENTINEL = "processor_config.json"
def _dataset_signature( def _dataset_signature(
dataset_repo_id: str, dataset_repo_id: str,
@@ -111,7 +120,7 @@ def fit_fast_tokenizer(
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size) sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
out_dir = cache_dir / sig out_dir = cache_dir / sig
if out_dir.exists() and (out_dir / "preprocessor_config.json").exists(): if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
logger.info( logger.info(
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for " "FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
"dataset=%s base=%s n_samples=%d", "dataset=%s base=%s n_samples=%d",
@@ -119,6 +128,32 @@ def fit_fast_tokenizer(
) )
return str(out_dir) return str(out_dir)
# DDP-safe fit: only the (local) main process actually fits + saves;
# other ranks poll the cache sentinel until the leader is done.
# Without this guard, all N ranks fit concurrently and race on
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
# and compiles a ``.pyc`` — concurrent writers occasionally produce
# a stale / partial ``.pyc`` and the subsequent ``from .. import
# UniversalActionProcessor`` raises ``AttributeError``.
is_leader = (
int(os.environ.get("RANK", "0")) == 0
and int(os.environ.get("LOCAL_RANK", "0")) == 0
)
if not is_leader:
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
start = time.monotonic()
while not (out_dir / _CACHE_SENTINEL).exists():
if time.monotonic() - start > timeout_s:
raise RuntimeError(
f"FAST tokenizer fit: non-leader rank timed out after "
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
"Leader rank likely crashed during the fit."
)
time.sleep(2.0)
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
return str(out_dir)
logger.info( logger.info(
"FAST tokenizer cache miss — fitting on dataset=%s " "FAST tokenizer cache miss — fitting on dataset=%s "
"base=%s n_samples=%d chunk_size=%d%s", "base=%s n_samples=%d chunk_size=%d%s",
+95 -48
View File
@@ -106,35 +106,53 @@ def _mask_per_sample(per_sample: Tensor, predict_actions_t: Tensor | None) -> Te
return (per_sample * mask).sum() / mask.sum().clamp(min=1.0) return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
def _shifted_ce(logits: Tensor, labels: Tensor, z_loss_weight: float = 0.0) -> Tensor: def _shifted_lin_ce(
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100. hidden: Tensor,
lm_head_weight: Tensor,
labels: Tensor,
z_loss_weight: float = 0.0,
) -> Tensor:
"""Liger-fused (hidden @ W.T → softmax → CE) on shifted labels.
Mean over non-ignored positions across the batch. Returns 0 cleanly Replaces the explicit ``lm_head(hidden) → F.cross_entropy(...)``
when no positions are supervised (clamp(min=1) on the denominator). pair with Liger's ``LigerFusedLinearCrossEntropyLoss``: the full
``(B, T, V)`` logits tensor is never materialised — the kernel
chunks over the (B*T) axis, computing matmul + logsumexp + CE
in fused Triton blocks. On a 257k-vocab head this saves ~10 GB
of activation memory per CE branch and ~30 % step time vs the
eager ``F.cross_entropy`` path.
When ``z_loss_weight > 0``, also adds PaLM-style z-loss Semantics:
(``z² · w``, where ``z = log Σ exp(logits)``) on every supervised * Shift convention identical to the eager version — hidden at
position. Penalises the log-partition function drifting away from position ``t`` predicts label at ``t+1``; ``ignore_index=-100``.
zero — without it, large-vocab models (PaliGemma is 257k) can let * No ``.any().item()`` sync — Liger returns 0.0 cleanly when
``logsumexp`` grow unboundedly while CE stays low, because uniform every label is ignored, keeping the graph capturable for
additive logit bias cancels in softmax. PaLM appendix B / Chinchilla ``compile_mode=reduce-overhead`` (CUDA graphs).
report this is essential for stable large-vocab CE; cheap insurance * ``z_loss_weight`` maps directly to Liger's ``lse_square_scale``
here especially with ``lm_head_lr_scale=5.0`` amplifying drift risk. (same ``z²·w`` formula on per-position logsumexp). Setting it
to 0 disables the z-loss term at zero cost.
""" """
shift_logits = logits[:, :-1, :].contiguous() # Liger is imported lazily so the module still imports on machines
# without liger-kernel; the call site only ever runs after
# use_hf_kernels / training has selected the Liger path.
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
LigerFusedLinearCrossEntropyLoss,
)
shift_hidden = hidden[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous().long() shift_labels = labels[:, 1:].contiguous().long()
valid = shift_labels != -100 B, T_1, H = shift_hidden.shape
if not bool(valid.any().item()): flat_hidden = shift_hidden.reshape(B * T_1, H)
return shift_logits.sum() * 0.0 flat_labels = shift_labels.reshape(B * T_1)
valid_logits = shift_logits[valid] # Match the dtype the eager path used: cast hidden to the lm_head's
valid_labels = shift_labels[valid] # weight dtype so bf16 weights see bf16 activations.
ce = F.cross_entropy(valid_logits, valid_labels, reduction="mean") flat_hidden = flat_hidden.to(lm_head_weight.dtype)
if z_loss_weight <= 0.0: loss_fn = LigerFusedLinearCrossEntropyLoss(
return ce ignore_index=-100,
# PaLM z-loss: penalise (log Σ exp(logits))² per supervised position. lse_square_scale=float(z_loss_weight),
# ``logsumexp`` is numerically stable and shares the softmax kernel. reduction="mean",
z = torch.logsumexp(valid_logits, dim=-1) )
return ce + z_loss_weight * (z**2).mean() return loss_fn(lm_head_weight, flat_hidden, flat_labels)
def _mark_target_span_causal( def _mark_target_span_causal(
@@ -172,32 +190,48 @@ def _mark_target_span_causal(
return att return att
def _fast_ce( def _fast_lin_ce(
fast_logits: Tensor, hidden: Tensor,
lm_head_weight: Tensor,
action_tokens: Tensor, action_tokens: Tensor,
action_code_mask: Tensor, action_code_mask: Tensor,
predict_actions_t: Tensor | None, predict_actions_t: Tensor | None,
) -> Tensor: ) -> Tensor:
"""FAST action-code CE with token-span masking and per-sample action gating. """Liger-fused FAST action-code CE with span masking + sample gating.
``action_code_mask`` is true only on the discrete action-code tokens, Mirrors ``_shifted_lin_ce`` but with FAST-specific masking: only
excluding the BOS / "Action: " / delimiter wrapper. Samples whose the discrete action-code positions (``action_code_mask``) are
recipe sets ``predict_actions=False`` get all code positions masked supervised, and samples whose recipe sets ``predict_actions=False``
out via the per-sample gate. get all code positions masked. Masked positions are folded into
Liger's ``ignore_index=-100`` so the kernel skips them without
a CPU-side gather (which would synchronise + break CUDA graphs).
""" """
shift_logits = fast_logits[:, :-1, :].contiguous() from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
LigerFusedLinearCrossEntropyLoss,
)
shift_hidden = hidden[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long() shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_code_mask[:, 1:].contiguous().bool() shift_valid = action_code_mask[:, 1:].contiguous().bool()
if predict_actions_t is not None: if predict_actions_t is not None:
sample_mask = predict_actions_t[:, None].expand_as(shift_valid) sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
shift_valid = shift_valid & sample_mask shift_valid = shift_valid & sample_mask
if not bool(shift_valid.any().item()): # Fold the boolean mask into the target via ignore_index. No
return shift_logits.sum() * 0.0 # ``.any().item()`` sync — Liger returns 0.0 when every position
return F.cross_entropy( # is ignored, preserving graph capture for CUDA graphs.
shift_logits[shift_valid], shift_targets = torch.where(
shift_targets[shift_valid], shift_valid, shift_targets, torch.full_like(shift_targets, -100)
)
B, T_1, H = shift_hidden.shape
flat_hidden = shift_hidden.reshape(B * T_1, H).to(lm_head_weight.dtype)
flat_labels = shift_targets.reshape(B * T_1)
loss_fn = LigerFusedLinearCrossEntropyLoss(
ignore_index=-100,
reduction="mean", reduction="mean",
) )
return loss_fn(lm_head_weight, flat_hidden, flat_labels)
# ---------------------------------------------------------------------- # ----------------------------------------------------------------------
@@ -726,9 +760,12 @@ class PI052Policy(PI05Policy):
text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :] text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :]
else: else:
text_hidden = prefix_out[:, -lang_len:, :] text_hidden = prefix_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) # Liger fused linear-CE: skip the explicit ``lm_head(...)``
text_loss = _shifted_ce( # materialisation; the kernel multiplies on-the-fly and
text_logits, # never holds the full (B, T, 257k) logits tensor.
text_loss = _shifted_lin_ce(
text_hidden,
lm_head.weight,
text_labels, text_labels,
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0), z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
) )
@@ -736,8 +773,13 @@ class PI052Policy(PI05Policy):
fast_loss: Tensor | None = None fast_loss: Tensor | None = None
if fast_len > 0 and prefix_out is not None and action_code_mask is not None: if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
fast_hidden = prefix_out[:, -fast_len:, :] fast_hidden = prefix_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) fast_loss = _fast_lin_ce(
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t) fast_hidden,
lm_head.weight,
action_tokens,
action_code_mask,
predict_actions_t,
)
return flow_loss, text_loss, fast_loss return flow_loss, text_loss, fast_loss
@@ -830,9 +872,9 @@ class PI052Policy(PI05Policy):
text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :] text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :]
else: else:
text_hidden = vlm_out[:, -lang_len:, :] text_hidden = vlm_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) text_loss = _shifted_lin_ce(
text_loss = _shifted_ce( text_hidden,
text_logits, lm_head.weight,
text_labels, text_labels,
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0), z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
) )
@@ -844,8 +886,13 @@ class PI052Policy(PI05Policy):
and fast_len > 0 and fast_len > 0
): ):
fast_hidden = vlm_out[:, -fast_len:, :] fast_hidden = vlm_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) fast_loss = _fast_lin_ce(
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t) fast_hidden,
lm_head.weight,
action_tokens,
action_code_mask,
predict_actions_t,
)
return text_loss, fast_loss return text_loss, fast_loss