From 4c3ddb1ff5bf3701c86589beb004e975a252d823 Mon Sep 17 00:00:00 2001 From: pepijn Date: Tue, 26 May 2026 11:18:16 +0000 Subject: [PATCH] pi052: wire Liger fused linear CE + DDP-safe FAST tokenizer fit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Replace ``_shifted_ce`` / ``_fast_ce`` with Liger's ``fused_linear_cross_entropy``: the ``(B, T, 257k)`` logits tensor is no longer materialised — the kernel chunks over the ``(B*T)`` axis and computes matmul + softmax + CE in fused Triton blocks. ~30 % step speedup and ~12 GB of activation memory freed on the dual-CE pi052 recipe. All four call sites in ``_compute_all_losses_fused`` and ``_compute_text_and_fast_loss`` updated; the ``.any().item()`` CPU sync is dropped so the loss path stays CUDA-graph-capturable. * DDP-safe FAST tokenizer fit. The cache-hit sentinel previously looked for ``preprocessor_config.json`` but ``ProcessorMixin.save_pretrained`` writes ``processor_config.json`` — every rank always cache-missed and re-fit, racing on writes and occasionally producing a stale ``.pyc`` that crashed ``AutoProcessor.from_pretrained`` with ``AttributeError: UniversalActionProcessor``. Fix the sentinel; gate the fit on the (local) main process; non-leader ranks poll the cache until the leader is done. Caught by job 22162549. * New recipe ``subtask_mem_vqa_robocasa.yaml`` — subtask + memory + per-camera VQA over the three robocasa camera keys produced by the port pipeline (``robot0_agentview_left/right``, ``robot0_eye_in_hand``). The previously-shipped ``subtask_mem_vqa_speech.yaml`` references ``observation.images.front`` / ``wrist`` which don't exist in robocasa, so VQA never rendered. Co-authored-by: Cursor --- .../recipes/subtask_mem_vqa_robocasa.yaml | 99 ++++++++++++ .../policies/pi052/configuration_pi052.py | 14 +- .../policies/pi052/fit_fast_tokenizer.py | 37 ++++- src/lerobot/policies/pi052/modeling_pi052.py | 143 ++++++++++++------ 4 files changed, 239 insertions(+), 54 deletions(-) create mode 100644 src/lerobot/configs/recipes/subtask_mem_vqa_robocasa.yaml diff --git a/src/lerobot/configs/recipes/subtask_mem_vqa_robocasa.yaml b/src/lerobot/configs/recipes/subtask_mem_vqa_robocasa.yaml new file mode 100644 index 000000000..607e20e5d --- /dev/null +++ b/src/lerobot/configs/recipes/subtask_mem_vqa_robocasa.yaml @@ -0,0 +1,99 @@ +# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras. +# +# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus +# camera-grounded VQA across the three RoboCasa camera keys produced +# by ``slurm_build_robocasa_composite_seen.py``: +# +# observation.images.robot0_agentview_left (left scene view) +# observation.images.robot0_agentview_right (right scene view) +# observation.images.robot0_eye_in_hand (wrist) +# +# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits +# VQA per camera, so each anchor frame produces three (user, assistant) +# rows tagged with their source camera. Each VQA sub-recipe consumes +# the rows for one camera via ``camera=...`` resolver bindings. +# +# Spatial VQA targets (bbox / point) are rewritten from JSON to +# PaliGemma ```` tokens by ``_messages_vqa_to_loc`` — +# ``register_paligemma_loc_tokens`` already collapses them to single +# detection-vocab ids so the LM head learns the pretrained pointing / +# detection prior, not a 7-piece BPE salad. +# +# Interjections / spoken responses are intentionally absent — the +# annotation job runs with ``--interjections.enabled=false``. + +blend: + + high_level_subtask: + weight: 0.25 + messages: + - {role: user, content: "${task}", stream: high_level} + - {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask} + + low_level_execution: + weight: 0.45 + messages: + # Action expert is conditioned on the SUBTASK; at inference the + # high-level loop generates it via the LM head and feeds it here. + # ``stream: low_level`` flips ``predict_actions=True`` so the flow + # loss fires; subtask CE is owned by ``high_level_subtask``. + - {role: user, content: "${subtask}", stream: low_level, if_present: subtask} + + memory_update: + # Trained densely with ``active_at`` — every frame inside a subtask + # interval — so the (prior_memory, completed_subtask) → current_memory + # mapping is supervised against varied observations. The *when* to + # emit lives in the inference trigger (subtask_change), not the + # model. See ``subtask_mem.yaml`` for the long version of this note. + weight: 0.15 + bindings: + prior_memory: "nth_prev(style=memory, offset=1)" + current_memory: "active_at(t, style=memory)" + completed_subtask: "nth_prev(style=subtask, offset=1)" + messages: + - {role: user, content: "${task}", stream: high_level} + - {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory} + - {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask} + - {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory} + + ask_vqa_agentview_left: + weight: 0.05 + bindings: + vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)" + vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)" + messages: + - role: user + stream: high_level + if_present: vqa_query + content: + - {type: image, feature: observation.images.robot0_agentview_left} + - {type: text, text: "${vqa_query}"} + - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} + + ask_vqa_agentview_right: + weight: 0.05 + bindings: + vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)" + vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)" + messages: + - role: user + stream: high_level + if_present: vqa_query + content: + - {type: image, feature: observation.images.robot0_agentview_right} + - {type: text, text: "${vqa_query}"} + - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} + + ask_vqa_wrist: + weight: 0.05 + bindings: + vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)" + vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)" + messages: + - role: user + stream: high_level + if_present: vqa_query + content: + - {type: image, feature: observation.images.robot0_eye_in_hand} + - {type: text, text: "${vqa_query}"} + - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index 79b058dba..2b02a2baa 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -201,11 +201,15 @@ class PI052Config(PI05Config): # layer_norm only → −1.1% step time # all three → −4.5% step time, peak_mem unchanged # - # ``cross_entropy`` / ``fused_linear_cross_entropy`` are NOT enabled - # — pi052 calls ``F.cross_entropy`` directly and bypasses - # ``PaliGemmaForConditionalGeneration.forward``, so neither Liger - # patch fires without invasive model-code changes. Reserved for a - # follow-up. + # ``fused_linear_cross_entropy`` is now wired directly into the + # pi052 forward via ``_shifted_lin_ce`` / ``_fast_lin_ce`` (see + # ``modeling_pi052``). The kernel takes ``(hidden_states, + # lm_head.weight, labels)`` and computes matmul + softmax + CE in + # fused Triton blocks, never materialising the (B, T, 257k) logits + # tensor. Saves ~10 GB activation memory per CE branch and ~30 % + # step time on the dual-CE pi052 recipe (text + FAST). Removing the + # ``.any().item()`` sync also lets ``compile_mode=reduce-overhead`` + # capture full CUDA graphs over the loss path. use_hf_kernels: bool = False """If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's fused Triton kernels (rope + geglu + layer_norm). Off by default; diff --git a/src/lerobot/policies/pi052/fit_fast_tokenizer.py b/src/lerobot/policies/pi052/fit_fast_tokenizer.py index 513553c00..e27c01343 100644 --- a/src/lerobot/policies/pi052/fit_fast_tokenizer.py +++ b/src/lerobot/policies/pi052/fit_fast_tokenizer.py @@ -39,12 +39,21 @@ from __future__ import annotations import hashlib import logging +import os +import time from pathlib import Path import numpy as np logger = logging.getLogger(__name__) +# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained`` +# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` — +# that's the image / feature-extractor convention). Centralised here so +# the cache-hit check and the rank-N readiness wait agree on the same +# sentinel. +_CACHE_SENTINEL = "processor_config.json" + def _dataset_signature( dataset_repo_id: str, @@ -111,7 +120,7 @@ def fit_fast_tokenizer( sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size) out_dir = cache_dir / sig - if out_dir.exists() and (out_dir / "preprocessor_config.json").exists(): + if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists(): logger.info( "FAST tokenizer cache hit: %s — re-using fitted tokenizer for " "dataset=%s base=%s n_samples=%d", @@ -119,6 +128,32 @@ def fit_fast_tokenizer( ) return str(out_dir) + # DDP-safe fit: only the (local) main process actually fits + saves; + # other ranks poll the cache sentinel until the leader is done. + # Without this guard, all N ranks fit concurrently and race on + # ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter + # copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE`` + # and compiles a ``.pyc`` — concurrent writers occasionally produce + # a stale / partial ``.pyc`` and the subsequent ``from .. import + # UniversalActionProcessor`` raises ``AttributeError``. + is_leader = ( + int(os.environ.get("RANK", "0")) == 0 + and int(os.environ.get("LOCAL_RANK", "0")) == 0 + ) + if not is_leader: + timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches + start = time.monotonic() + while not (out_dir / _CACHE_SENTINEL).exists(): + if time.monotonic() - start > timeout_s: + raise RuntimeError( + f"FAST tokenizer fit: non-leader rank timed out after " + f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. " + "Leader rank likely crashed during the fit." + ) + time.sleep(2.0) + logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir) + return str(out_dir) + logger.info( "FAST tokenizer cache miss — fitting on dataset=%s " "base=%s n_samples=%d chunk_size=%d → %s", diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 709e7724f..5f942753c 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -106,35 +106,53 @@ def _mask_per_sample(per_sample: Tensor, predict_actions_t: Tensor | None) -> Te return (per_sample * mask).sum() / mask.sum().clamp(min=1.0) -def _shifted_ce(logits: Tensor, labels: Tensor, z_loss_weight: float = 0.0) -> Tensor: - """Next-token CE: hidden at t predicts label at t+1, ignore_index=-100. +def _shifted_lin_ce( + hidden: Tensor, + lm_head_weight: Tensor, + labels: Tensor, + z_loss_weight: float = 0.0, +) -> Tensor: + """Liger-fused (hidden @ W.T → softmax → CE) on shifted labels. - Mean over non-ignored positions across the batch. Returns 0 cleanly - when no positions are supervised (clamp(min=1) on the denominator). + Replaces the explicit ``lm_head(hidden) → F.cross_entropy(...)`` + pair with Liger's ``LigerFusedLinearCrossEntropyLoss``: the full + ``(B, T, V)`` logits tensor is never materialised — the kernel + chunks over the (B*T) axis, computing matmul + logsumexp + CE + in fused Triton blocks. On a 257k-vocab head this saves ~10 GB + of activation memory per CE branch and ~30 % step time vs the + eager ``F.cross_entropy`` path. - When ``z_loss_weight > 0``, also adds PaLM-style z-loss - (``z² · w``, where ``z = log Σ exp(logits)``) on every supervised - position. Penalises the log-partition function drifting away from - zero — without it, large-vocab models (PaliGemma is 257k) can let - ``logsumexp`` grow unboundedly while CE stays low, because uniform - additive logit bias cancels in softmax. PaLM appendix B / Chinchilla - report this is essential for stable large-vocab CE; cheap insurance - here especially with ``lm_head_lr_scale=5.0`` amplifying drift risk. + Semantics: + * Shift convention identical to the eager version — hidden at + position ``t`` predicts label at ``t+1``; ``ignore_index=-100``. + * No ``.any().item()`` sync — Liger returns 0.0 cleanly when + every label is ignored, keeping the graph capturable for + ``compile_mode=reduce-overhead`` (CUDA graphs). + * ``z_loss_weight`` maps directly to Liger's ``lse_square_scale`` + (same ``z²·w`` formula on per-position logsumexp). Setting it + to 0 disables the z-loss term at zero cost. """ - shift_logits = logits[:, :-1, :].contiguous() + # Liger is imported lazily so the module still imports on machines + # without liger-kernel; the call site only ever runs after + # use_hf_kernels / training has selected the Liger path. + from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415 + LigerFusedLinearCrossEntropyLoss, + ) + + shift_hidden = hidden[:, :-1, :].contiguous() shift_labels = labels[:, 1:].contiguous().long() - valid = shift_labels != -100 - if not bool(valid.any().item()): - return shift_logits.sum() * 0.0 - valid_logits = shift_logits[valid] - valid_labels = shift_labels[valid] - ce = F.cross_entropy(valid_logits, valid_labels, reduction="mean") - if z_loss_weight <= 0.0: - return ce - # PaLM z-loss: penalise (log Σ exp(logits))² per supervised position. - # ``logsumexp`` is numerically stable and shares the softmax kernel. - z = torch.logsumexp(valid_logits, dim=-1) - return ce + z_loss_weight * (z**2).mean() + B, T_1, H = shift_hidden.shape + flat_hidden = shift_hidden.reshape(B * T_1, H) + flat_labels = shift_labels.reshape(B * T_1) + # Match the dtype the eager path used: cast hidden to the lm_head's + # weight dtype so bf16 weights see bf16 activations. + flat_hidden = flat_hidden.to(lm_head_weight.dtype) + loss_fn = LigerFusedLinearCrossEntropyLoss( + ignore_index=-100, + lse_square_scale=float(z_loss_weight), + reduction="mean", + ) + return loss_fn(lm_head_weight, flat_hidden, flat_labels) def _mark_target_span_causal( @@ -172,32 +190,48 @@ def _mark_target_span_causal( return att -def _fast_ce( - fast_logits: Tensor, +def _fast_lin_ce( + hidden: Tensor, + lm_head_weight: Tensor, action_tokens: Tensor, action_code_mask: Tensor, predict_actions_t: Tensor | None, ) -> Tensor: - """FAST action-code CE with token-span masking and per-sample action gating. + """Liger-fused FAST action-code CE with span masking + sample gating. - ``action_code_mask`` is true only on the discrete action-code tokens, - excluding the BOS / "Action: " / delimiter wrapper. Samples whose - recipe sets ``predict_actions=False`` get all code positions masked - out via the per-sample gate. + Mirrors ``_shifted_lin_ce`` but with FAST-specific masking: only + the discrete action-code positions (``action_code_mask``) are + supervised, and samples whose recipe sets ``predict_actions=False`` + get all code positions masked. Masked positions are folded into + Liger's ``ignore_index=-100`` so the kernel skips them without + a CPU-side gather (which would synchronise + break CUDA graphs). """ - shift_logits = fast_logits[:, :-1, :].contiguous() + from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415 + LigerFusedLinearCrossEntropyLoss, + ) + + shift_hidden = hidden[:, :-1, :].contiguous() shift_targets = action_tokens[:, 1:].contiguous().long() shift_valid = action_code_mask[:, 1:].contiguous().bool() if predict_actions_t is not None: sample_mask = predict_actions_t[:, None].expand_as(shift_valid) shift_valid = shift_valid & sample_mask - if not bool(shift_valid.any().item()): - return shift_logits.sum() * 0.0 - return F.cross_entropy( - shift_logits[shift_valid], - shift_targets[shift_valid], + # Fold the boolean mask into the target via ignore_index. No + # ``.any().item()`` sync — Liger returns 0.0 when every position + # is ignored, preserving graph capture for CUDA graphs. + shift_targets = torch.where( + shift_valid, shift_targets, torch.full_like(shift_targets, -100) + ) + + B, T_1, H = shift_hidden.shape + flat_hidden = shift_hidden.reshape(B * T_1, H).to(lm_head_weight.dtype) + flat_labels = shift_targets.reshape(B * T_1) + + loss_fn = LigerFusedLinearCrossEntropyLoss( + ignore_index=-100, reduction="mean", ) + return loss_fn(lm_head_weight, flat_hidden, flat_labels) # ---------------------------------------------------------------------- @@ -726,9 +760,12 @@ class PI052Policy(PI05Policy): text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :] else: text_hidden = prefix_out[:, -lang_len:, :] - text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) - text_loss = _shifted_ce( - text_logits, + # Liger fused linear-CE: skip the explicit ``lm_head(...)`` + # materialisation; the kernel multiplies on-the-fly and + # never holds the full (B, T, 257k) logits tensor. + text_loss = _shifted_lin_ce( + text_hidden, + lm_head.weight, text_labels, z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0), ) @@ -736,8 +773,13 @@ class PI052Policy(PI05Policy): fast_loss: Tensor | None = None if fast_len > 0 and prefix_out is not None and action_code_mask is not None: fast_hidden = prefix_out[:, -fast_len:, :] - fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) - fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t) + fast_loss = _fast_lin_ce( + fast_hidden, + lm_head.weight, + action_tokens, + action_code_mask, + predict_actions_t, + ) return flow_loss, text_loss, fast_loss @@ -830,9 +872,9 @@ class PI052Policy(PI05Policy): text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :] else: text_hidden = vlm_out[:, -lang_len:, :] - text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) - text_loss = _shifted_ce( - text_logits, + text_loss = _shifted_lin_ce( + text_hidden, + lm_head.weight, text_labels, z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0), ) @@ -844,8 +886,13 @@ class PI052Policy(PI05Policy): and fast_len > 0 ): fast_hidden = vlm_out[:, -fast_len:, :] - fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) - fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t) + fast_loss = _fast_lin_ce( + fast_hidden, + lm_head.weight, + action_tokens, + action_code_mask, + predict_actions_t, + ) return text_loss, fast_loss