feat(pi052): FAST action CE loss + knowledge insulation + processor wiring

Three features ported from ``pi05_full`` on branch ``feat/add-pi05``
(plus the config flags that gate them), giving pi052 the paper's
§III.B-C training capabilities alongside the recipe-driven text
supervision it already had:

* **Config flags** in PI052Config:
    - ``enable_fast_action_loss``  default False
    - ``action_tokenizer_name``    default "physical-intelligence/fast"
    - ``max_action_tokens``        default 256
    - ``fast_skip_tokens``         default 128
    - ``fast_action_loss_weight``  default 1.0
    - ``knowledge_insulation``     default False

* **Processor wiring** (processor_pi052.py): when
  ``enable_fast_action_loss=True``, append an
  ``ActionTokenizerProcessorStep`` after the text tokenizer. It
  tokenises the action tensor with the FAST tokenizer and writes
  ACTION_TOKENS / ACTION_TOKEN_MASK into ``COMPLEMENTARY_DATA`` —
  the existing batch-collation pipeline forwards them as
  ``batch['action.tokens']`` / ``batch['action.token_mask']``.
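  For reference, a toy view of the forwarded keys (shapes are
  illustrative, not normative):

     batch["action.tokens"]      # (B, max_action_tokens) int64 FAST ids
     batch["action.token_mask"]  # (B, max_action_tokens) bool, True at real ids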

* **FAST CE loss** (modeling_pi052.py::_compute_fast_action_loss):
  Re-embeds the prefix [images, language], appends the FAST token
  embeddings (using PaliGemma's shared embed_language_tokens),
  forwards through the backbone, slices the trailing
  ``fast_len`` positions, applies the LM head, computes shifted
  next-token CE with the action-mask gating the loss. The loss is
  summed into ``forward()``'s total with ``fast_action_loss_weight``.
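  In toy shapes (hypothetical vocab of 32 and chunk of 5 tokens — not
  the real sizes), the shifted, mask-gated CE reduces to:

     import torch
     import torch.nn.functional as F

     logits  = torch.randn(1, 5, 32)             # (B, fast_len, vocab)
     targets = torch.tensor([[4, 7, 9, 1, 0]])   # FAST ids, last is pad
     mask    = torch.tensor([[1., 1., 1., 1., 0.]])

     s_logits, s_targets, s_mask = logits[:, :-1], targets[:, 1:], mask[:, 1:]
     ce = F.cross_entropy(
         s_logits.reshape(-1, 32), s_targets.reshape(-1), reduction="none"
     ).view(s_targets.shape)
     loss = (ce * s_mask).sum() / s_mask.sum().clamp_min(1.0)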

* **Knowledge insulation** (modeling_pi052.py::_compute_layer_ki +
  _paligemma_forward_ki): port of pi05_full's per-layer attention
  that detaches VLM K/V on the action-query path so action loss
  gradients cannot flow back into the VLM's K/V projections. Bound
  per-instance via ``types.MethodType`` so it doesn't leak into
  stock ``pi05`` policies that share PaliGemmaWithExpertModel.
  Activated automatically when ``config.knowledge_insulation=True``.
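  A standalone toy demo (not pi052 code) of why a per-instance bind
  cannot leak to other instances of the shared class:

     import types

     class Expert:
         def forward(self):
             return "stock"

     a, b = Expert(), Expert()
     a._orig_forward = a.forward                  # keep the original
     a.forward = types.MethodType(lambda self: "insulated", a)
     assert a.forward() == "insulated"
     assert b.forward() == "stock"                # class method untouched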

Combined with the existing recipe-driven text head, pi052 now
supports the full three-loss objective:

   L = text_w·H(text) + fast_w·H(FAST actions) + flow_w·MSE(flow)

matching Eq. (1) of arXiv:2504.16054 §IV.D (α = 10 by default for the
flow term, 1.0 each for the text and FAST CE terms).
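
For example, a pre-training-style run enables all three terms (a
sketch using only the PI052Config fields introduced above):

   cfg = PI052Config(
       enable_fast_action_loss=True,   # adds fast_w·H(FAST actions)
       fast_action_loss_weight=1.0,
       knowledge_insulation=True,      # blocks action→VLM K/V grads
   )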

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -99,6 +99,50 @@ class PI052Config(PI05Config):
    memory_dropout_prob: float = 0.0
    subtask_dropout_prob: float = 0.0

    # FAST discrete-action supervision — paper §III.B-C ------------------
    # When enabled, actions are *also* tokenised via the FAST tokenizer
    # ("physical-intelligence/fast") and supervised with cross-entropy
    # on the PaliGemma LM head — exactly as in the paper's pre-training
    # objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
    # ActionTokenizerProcessorStep is wired into the preprocessor
    # pipeline when this flag is set; the loss is computed in
    # PI052Policy.forward.
    enable_fast_action_loss: bool = False
    """If True, tokenise actions with the FAST tokenizer and add a
    cross-entropy loss on the LM head. Off by default because most
    fine-tuning runs only need the flow head + text supervision; the
    FAST CE term is most useful when training from a base PaliGemma
    rather than an existing π0.5 checkpoint."""

    action_tokenizer_name: str = "physical-intelligence/fast"
    """HF identifier for the FAST action tokenizer."""

    max_action_tokens: int = 256
    """Maximum number of FAST tokens per action chunk."""

    fast_skip_tokens: int = 128
    """Number of low-vocab tokens the FAST tokenizer skips to avoid
    collisions with PaliGemma's text vocabulary."""

    fast_action_loss_weight: float = 1.0
    """Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""

    # Knowledge insulation — paper §III.B --------------------------------
    # When enabled, gradients from the action expert's flow loss are
    # *blocked* from flowing back into the VLM's K/V projections. This
    # prevents the action loss from over-fitting the language backbone
    # to robot-specific features. Implemented as a custom per-layer
    # attention forward that uses ``.detach()`` on VLM K/V when
    # computing attention for action queries — ported from
    # ``pi05_full/modeling_pi05.py::compute_layer_complete_knowledge_insulation``
    # on branch ``feat/add-pi05``.
    knowledge_insulation: bool = False
    """If True, route every transformer layer through the KI
    attention path (blocks action→VLM gradient flow on K/V). The
    custom forward (``_paligemma_forward_ki``) is bound per-instance
    onto the backbone in ``PI052Policy.__init__`` when this is set."""

    def __post_init__(self) -> None:
        super().__post_init__()
        # Backbone needs gradients flowing through the text head when
@@ -39,10 +39,12 @@ from __future__ import annotations
import logging
import math
import types
from typing import Any
import torch
from torch import Tensor
from torch.nn import functional as F
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
@@ -53,6 +55,190 @@ from .configuration_pi052 import PI052Config
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------
# Knowledge insulation — ported from pi05_full (branch ``feat/add-pi05``)
# ----------------------------------------------------------------------
#
# Per-layer attention that splits the queries into VLM and action
# parts, computing attention for action queries with .detach()'d VLM
# K/V so the action loss's gradient cannot flow back into the VLM's K
# and V projections. Forward output is bit-equivalent to the standard
# layer; backward differs only on the path action_loss → VLM K/V.
def _compute_layer_ki(
    layer_idx,
    inputs_embeds,
    attention_mask,
    position_ids,
    adarms_cond,
    paligemma,
    gemma_expert,
):
    from transformers.models.gemma import modeling_gemma  # noqa: PLC0415

    models = [paligemma.language_model, gemma_expert.model]
    query_states, key_states, value_states, gates = [], [], [], []
    vlm_len = inputs_embeds[0].shape[1]

    # Project Q/K/V for both experts (VLM first, then action expert).
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        q = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        k = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        v = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states.append(q)
        key_states.append(k)
        value_states.append(v)

    query_states = torch.cat(query_states, dim=2)
    key_states = torch.cat(key_states, dim=2)
    value_states = torch.cat(value_states, dim=2)

    dummy = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )

    batch_size = query_states.shape[0]
    scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling

    # Split queries / K / V at the VLM-vs-action boundary.
    Q_vlm = query_states[:, :, :vlm_len, :]
    Q_action = query_states[:, :, vlm_len:, :]
    K_vlm = key_states[:, :, :vlm_len, :]
    K_action = key_states[:, :, vlm_len:, :]
    V_vlm = value_states[:, :, :vlm_len, :]
    V_action = value_states[:, :, vlm_len:, :]

    # Detach VLM K/V *only* on the path the action queries use.
    K_vlm_det = K_vlm.detach()
    V_vlm_det = V_vlm.detach()
    K_for_vlm = key_states  # full (gradients flow)
    V_for_vlm = value_states
    K_for_action = torch.cat([K_vlm_det, K_action], dim=2)
    V_for_action = torch.cat([V_vlm_det, V_action], dim=2)

    mask_for_vlm = attention_mask[:, :, :vlm_len, :]
    mask_for_action = attention_mask[:, :, vlm_len:, :]

    att_vlm, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        Q_vlm, K_for_vlm, V_for_vlm, mask_for_vlm, scaling,
    )
    att_action, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        Q_action, K_for_action, V_for_action, mask_for_action, scaling,
    )
    att = torch.cat([att_vlm, att_action], dim=1)  # concat along sequence dim

    head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
    att = att.reshape(batch_size, -1, 1 * 8 * head_dim)  # 8 = attention heads

    # o_proj + gated residual + MLP, per expert, on its slice of ``att``.
    outputs_embeds = []
    start = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        end = start + hidden_states.shape[1]
        if att.dtype != layer.self_attn.o_proj.weight.dtype:
            att = att.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att[:, start:end])
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        after_first = out_emb.clone()
        out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        out_emb = modeling_gemma._gated_residual(after_first, out_emb, gate)  # noqa: SLF001
        outputs_embeds.append(out_emb)
        start = end
    return outputs_embeds
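
# A minimal standalone illustration (not part of this module) of the
# stop-gradient mechanism above — hypothetical toy shapes; action-side
# queries attend to detached VLM K/V, so no gradients reach K_vlm / V_vlm:
#
#     import torch
#     B, H, Tv, Ta, D = 1, 2, 3, 2, 4
#     K_vlm = torch.randn(B, H, Tv, D, requires_grad=True)
#     V_vlm = torch.randn(B, H, Tv, D, requires_grad=True)
#     Q_act = torch.randn(B, H, Ta, D, requires_grad=True)
#     att = torch.softmax(Q_act @ K_vlm.detach().transpose(-1, -2) / D**0.5, dim=-1)
#     (att @ V_vlm.detach()).sum().backward()
#     assert K_vlm.grad is None and V_vlm.grad is None  # insulated
#     assert Q_act.grad is not None                     # action path still learns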
def _paligemma_forward_ki(
    self,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    use_cache=None,
    adarms_cond=None,
    fill_kv_cache=None,
):
    """Replacement ``PaliGemmaWithExpertModel.forward`` that routes the
    dual-expert layer pass through :func:`_compute_layer_ki`.

    Bound onto the model instance when ``config.knowledge_insulation``
    is True (see ``PI052Policy.__init__``). Single-expert branches
    (VLM-only or action-only) reuse the original forward saved at bind
    time, because there's no KI signal to add — KI only matters when
    actions and VLM tokens are forwarded together.
    """
    from ..pi05.modeling_pi05 import layernorm_forward  # noqa: PLC0415

    if adarms_cond is None:
        adarms_cond = [None, None]
    # Single-expert paths: defer to the original bound forward, which
    # ``PI052Policy.__init__`` saves as ``_pi052_orig_forward`` before
    # rebinding (so it is always present here).
    if inputs_embeds[0] is None or inputs_embeds[1] is None:
        return self._pi052_orig_forward(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            adarms_cond=adarms_cond,
            fill_kv_cache=fill_kv_cache,
        )
    models = [self.paligemma.model.language_model, self.gemma_expert.model]
    num_layers = self.paligemma.config.text_config.num_hidden_layers
    use_gc = (
        hasattr(self.gemma_expert.model, "gradient_checkpointing")
        and self.gemma_expert.model.gradient_checkpointing
        and self.training
    ) or (
        hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training
    )

    for layer_idx in range(num_layers):
        if use_gc:
            inputs_embeds = torch.utils.checkpoint.checkpoint(
                _compute_layer_ki,
                layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
                use_reentrant=False, preserve_rng_state=False,
                paligemma=self.paligemma, gemma_expert=self.gemma_expert,
            )
        else:
            inputs_embeds = _compute_layer_ki(
                layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
                paligemma=self.paligemma, gemma_expert=self.gemma_expert,
            )

    # Final per-expert norm.
    outputs_embeds = []
    for i, hidden_states in enumerate(inputs_embeds):
        out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
        outputs_embeds.append(out_emb)
    return [outputs_embeds[0], outputs_embeds[1]], None
class PI052Policy(PI05Policy):
    """π0.5 with the PaliGemma LM head re-enabled."""
@@ -68,6 +254,20 @@ class PI052Policy(PI05Policy):
        if config.text_loss_weight > 0 and config.unfreeze_lm_head:
            self._unfreeze_lm_head()

        # Knowledge insulation: bind a custom ``forward`` on the
        # PaliGemmaWithExpertModel instance that uses
        # :func:`_compute_layer_ki` for the dual-expert layer pass.
        # The bind is per-instance, so this doesn't leak into stock
        # ``pi05`` policies that share the same class.
        if getattr(config, "knowledge_insulation", False):
            backbone = self.model.paligemma_with_expert
            backbone._pi052_orig_forward = backbone.forward
            backbone.forward = types.MethodType(_paligemma_forward_ki, backbone)
            logger.info(
                "PI052: knowledge insulation enabled — action→VLM K/V "
                "gradients are blocked in attention."
            )

    # ------------------------------------------------------------------
    # Head unfreeze helper
    # ------------------------------------------------------------------
@@ -143,6 +343,22 @@ class PI052Policy(PI05Policy):
            else total + self.config.text_loss_weight * text_loss
        )

        # FAST action-token CE loss (paper §III.C). When
        # ``enable_fast_action_loss=True`` the preprocessor wrote
        # ACTION_TOKENS / ACTION_TOKEN_MASK into the batch — we
        # forward them through the PaliGemma backbone alongside the
        # language prefix and compute CE on the action positions.
        if getattr(self.config, "enable_fast_action_loss", False):
            from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS  # noqa: PLC0415

            action_tokens = batch.get(ACTION_TOKENS)
            action_mask = batch.get(ACTION_TOKEN_MASK)
            if action_tokens is not None and action_mask is not None:
                fast_loss = self._compute_fast_action_loss(batch, action_tokens, action_mask)
                loss_dict["fast_action_loss"] = float(fast_loss.detach().item())
                weighted = self.config.fast_action_loss_weight * fast_loss
                total = weighted if total is None else total + weighted

        if total is None:
            # All loss terms (flow, text, FAST) disabled — make this an
            # obvious bug rather than a silent zero loss.
@@ -161,6 +377,82 @@ class PI052Policy(PI05Policy):
    # Text loss
    # ------------------------------------------------------------------
    def _compute_fast_action_loss(
        self,
        batch: dict[str, Tensor],
        action_tokens: Tensor,
        action_mask: Tensor,
    ) -> Tensor:
        """Cross-entropy on FAST-tokenised actions via PaliGemma's LM head.

        Mirrors the paper's §III.C action-token loss: append the FAST
        tokens to the language prefix, forward through the backbone,
        slice the per-token logits at the action positions, and
        compute CE against the FAST targets (shifted for next-token
        prediction). The token mask gates the loss to valid positions.
        """
        from ..pi05.modeling_pi05 import make_att_2d_masks  # noqa: PLC0415

        images, img_masks = self.model._preprocess_images(batch)
        lang_tokens = batch[OBS_LANGUAGE_TOKENS]
        lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]

        # Re-embed prefix to get [images, language] embeddings, then
        # append the FAST action token embeddings as additional
        # prefix tokens. PaliGemma's embed_language_tokens is shared
        # between text and FAST tokens so we can re-use it directly.
        prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        emb_dim = prefix_embs.shape[-1]
        fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
        fast_emb = fast_emb * math.sqrt(emb_dim)  # Gemma embedding scaling

        # Concat onto the prefix. Pad masks: language uses
        # ``lang_masks``; FAST uses ``action_mask`` (True at valid
        # token positions). Attention masks add ``True`` (causal)
        # for FAST so they can attend to the bidirectional prefix
        # but only causally among themselves.
        bsize, fast_len = action_tokens.shape
        device = prefix_embs.device
        ones_att = torch.ones((bsize, fast_len), dtype=torch.bool, device=device)
        full_embs = torch.cat([prefix_embs, fast_emb], dim=1)
        full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
        full_att = torch.cat([prefix_att, ones_att], dim=1)
        att_2d = make_att_2d_masks(full_pad, full_att)
        position_ids = torch.cumsum(full_pad, dim=1) - 1

        (vlm_out, _), _ = self.model.paligemma_with_expert.forward(
            attention_mask=att_2d,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=[full_embs, None],
            use_cache=False,
            fill_kv_cache=True,
        )
        if vlm_out is None:
            raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.")

        # Slice the last ``fast_len`` positions — those correspond to
        # the FAST tokens we just appended.
        fast_hidden = vlm_out[:, -fast_len:, :]
        lm_head = self.model.paligemma_with_expert.paligemma.lm_head
        fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))

        # Shift for next-token prediction (the first FAST token is not
        # supervised; its prediction would come from the last prefix
        # position, which is dropped here).
        shift_logits = fast_logits[:, :-1, :].contiguous()
        shift_targets = action_tokens[:, 1:].contiguous()
        shift_mask = action_mask[:, 1:].contiguous().float()
        per_tok = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_targets.view(-1),
            reduction="none",
        ).view(shift_targets.shape)
        loss = (per_tok * shift_mask).sum() / shift_mask.sum().clamp_min(1.0)
        return loss
    def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
        """Cross-entropy on PaliGemma's LM head over the supervised span.
@@ -43,6 +43,7 @@ import torch
from lerobot.configs.recipe import TrainingRecipe
from lerobot.processor import (
    AbsoluteActionsProcessorStep,
    ActionTokenizerProcessorStep,
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
@@ -101,9 +102,25 @@ def make_pi052_pre_post_processors(
            memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
            subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
        ),
    ]

    # FAST tokenizer for discrete-action CE supervision (paper §III.C).
    # Only inserted when explicitly enabled — keeps the post-training-
    # style recipe (flow + text) as the default. When on, the step
    # writes ACTION_TOKENS / ACTION_TOKEN_MASK into
    # ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
    # DeviceProcessorStep is appended last so it stays the final step.
    if getattr(config, "enable_fast_action_loss", False):
        input_steps.append(
            ActionTokenizerProcessorStep(
                action_tokenizer_name=config.action_tokenizer_name,
                max_action_tokens=config.max_action_tokens,
                fast_skip_tokens=config.fast_skip_tokens,
                paligemma_tokenizer_name="google/paligemma-3b-pt-224",
            )
        )
    input_steps.append(DeviceProcessorStep(device=config.device))
    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features,