feat(pi052): FAST action CE loss + knowledge insulation + processor wiring

Three additions ported from ``pi05_full`` on branch ``feat/add-pi05``,
giving pi052 full paper-§III.B-C training capabilities alongside the
recipe-driven text supervision it already had:

* **Config flags** in PI052Config:
    - ``enable_fast_action_loss``  default False
    - ``action_tokenizer_name``    default "physical-intelligence/fast"
    - ``max_action_tokens``        default 256
    - ``fast_skip_tokens``         default 128
    - ``fast_action_loss_weight``  default 1.0
    - ``knowledge_insulation``     default False

* **Processor wiring** (processor_pi052.py): when
  ``enable_fast_action_loss=True``, append an
  ``ActionTokenizerProcessorStep`` after the text tokenizer. It
  tokenises the action tensor with the FAST tokenizer and writes
  ACTION_TOKENS / ACTION_TOKEN_MASK into ``COMPLEMENTARY_DATA`` —
  the existing batch-collation pipeline forwards them as
  ``batch['action.tokens']`` / ``batch['action.token_mask']``.

* **FAST CE loss** (modeling_pi052.py::_compute_fast_action_loss):
  Re-embeds the prefix [images, language], appends the FAST token
  embeddings (using PaliGemma's shared embed_language_tokens),
  forwards through the backbone, slices the trailing
  ``fast_len`` positions, applies the LM head, and computes shifted
  next-token CE with the action mask gating the loss. The result is
  added to ``forward()``'s total, scaled by ``fast_action_loss_weight``.

* **Knowledge insulation** (modeling_pi052.py::_compute_layer_ki +
  _paligemma_forward_ki): port of pi05_full's per-layer attention
  that detaches VLM K/V on the action-query path so action loss
  gradients cannot flow back into the VLM's K/V projections. Bound
  per-instance via ``types.MethodType`` so it doesn't leak into
  stock ``pi05`` policies that share PaliGemmaWithExpertModel.
  Activated automatically when ``config.knowledge_insulation=True``.
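
The gradient-blocking idea behind knowledge insulation can be shown in
isolation. The following is a minimal sketch in plain PyTorch, not the
ported ``_compute_layer_ki``: the toy sizes, the single unmasked
attention head, and the helper names ``attend`` and ``action_output``
are assumptions made for illustration only.

```python
import torch

torch.manual_seed(0)

# Toy sizes (assumed): 3 "VLM" tokens followed by 2 "action" tokens, width 4.
vlm_len, act_len, dim = 3, 2, 4
q = torch.randn(1, vlm_len + act_len, dim)
k_vlm = torch.randn(1, vlm_len, dim, requires_grad=True)
v_vlm = torch.randn(1, vlm_len, dim, requires_grad=True)
k_act = torch.randn(1, act_len, dim, requires_grad=True)
v_act = torch.randn(1, act_len, dim, requires_grad=True)


def attend(query, key, value):
    """Plain softmax attention with no mask (enough for a gradient check)."""
    scores = query @ key.transpose(-1, -2) / dim**0.5
    return torch.softmax(scores, dim=-1) @ value


def action_output(insulate: bool):
    """Attention output for the action queries only (hypothetical helper)."""
    # With insulation, the VLM keys/values are detached on this path only.
    k_ctx = k_vlm.detach() if insulate else k_vlm
    v_ctx = v_vlm.detach() if insulate else v_vlm
    keys = torch.cat([k_ctx, k_act], dim=1)
    values = torch.cat([v_ctx, v_act], dim=1)
    return attend(q[:, vlm_len:], keys, values)


out_plain = action_output(insulate=False)
out_ki = action_output(insulate=True)

# detach() changes the autograd graph, not the forward values.
same_forward = torch.equal(out_plain, out_ki)

grads_plain = torch.autograd.grad(out_plain.sum(), [k_vlm, v_vlm, k_act])
# allow_unused=True returns None for inputs that are not in the graph.
grads_ki = torch.autograd.grad(out_ki.sum(), [k_vlm, v_vlm, k_act], allow_unused=True)
```

In this sketch the two forward outputs are identical, the action
tokens' own keys still receive a gradient, and the VLM key/value
tensors receive none once they are detached on the action-query path.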

Combined with the existing recipe-driven text head, pi052 now
supports the full three-loss objective:

   L = text_w·H(text) + fast_w·H(FAST actions) + flow_w·MSE(flow)

matching Eq. (1) of arxiv:2504.16054 §IV.D (α=10 by default for the
flow term, 1.0 each for text and FAST CE).
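
As a rough illustration of how the masked, shifted CE term and the
weighted sum fit together, here is a dependency-free sketch. It is not
the code from this commit: ``masked_shifted_ce``, ``combine`` and the
``flow_w`` argument are hypothetical names, and plain Python lists
stand in for tensors. The default weights follow the values quoted
above (1.0 for text and FAST CE, 10 for the flow term).

```python
import math


def masked_shifted_ce(logits, tokens, mask):
    """Mean next-token cross-entropy over positions whose target is valid.

    logits: T rows of V scores, tokens: T token ids,
    mask: T booleans (True where the token is real, False for padding).
    """
    total, count = 0.0, 0
    # Position t predicts token t + 1, so the last position has no target.
    for t in range(len(tokens) - 1):
        if not mask[t + 1]:
            continue
        row = logits[t]
        peak = max(row)  # subtract the max for a stable log-sum-exp
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_norm - row[tokens[t + 1]]
        count += 1
    # Same guard as clamp_min(1.0): an all-padding row yields 0, not NaN.
    return total / max(count, 1)


def combine(text_ce, fast_ce, flow_mse, text_w=1.0, fast_w=1.0, flow_w=10.0):
    """Weighted three-term objective; the weight names are assumptions."""
    return text_w * text_ce + fast_w * fast_ce + flow_w * flow_mse


# Uniform logits over a 4-token vocabulary give a CE of log(4) per position.
uniform = [[0.0] * 4 for _ in range(4)]
fast_ce = masked_shifted_ce(uniform, [0, 1, 2, 3], [True, True, True, False])
total = combine(text_ce=1.0, fast_ce=2.0, flow_mse=0.5)
```

Only targets at unmasked positions contribute, so the padded final
token in the example is ignored and the mean is taken over two
positions.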

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn
2026-05-13 11:46:21 +02:00
parent 8eba704f15
commit 8dc0af3c28
3 changed files with 354 additions and 1 deletions
@@ -99,6 +99,50 @@ class PI052Config(PI05Config):
    memory_dropout_prob: float = 0.0
    subtask_dropout_prob: float = 0.0

    # FAST discrete-action supervision — paper §III.B-C ------------------
    # When enabled, actions are *also* tokenised via the FAST tokenizer
    # ("physical-intelligence/fast") and supervised with cross-entropy
    # on the PaliGemma LM head — exactly as in the paper's pre-training
    # objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
    # ActionTokenizerProcessorStep is wired into the preprocessor
    # pipeline when this flag is set; the loss is computed in
    # PI052Policy.forward.
    enable_fast_action_loss: bool = False
    """If True, tokenise actions with the FAST tokenizer and add a
    cross-entropy loss on the LM head. Off by default because most
    fine-tuning runs only need the flow head + text supervision; the
    FAST CE term is most useful when training from a base PaliGemma
    rather than an existing π0.5 checkpoint."""
    action_tokenizer_name: str = "physical-intelligence/fast"
    """HF identifier for the FAST action tokenizer."""
    max_action_tokens: int = 256
    """Maximum number of FAST tokens per action chunk."""
    fast_skip_tokens: int = 128
    """Number of low-vocab tokens the FAST tokenizer skips to avoid
    collisions with PaliGemma's text vocabulary."""
    fast_action_loss_weight: float = 1.0
    """Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""

    # Knowledge insulation — paper §III.B --------------------------------
    # When enabled, gradients from the action expert's flow loss are
    # *blocked* from flowing back into the VLM's K/V projections. This
    # prevents the action loss from over-fitting the language backbone
    # to robot-specific features. Implementation requires a custom
    # per-layer attention forward that uses ``.detach()`` on VLM K/V
    # when computing attention for action queries — see
    # ``pi05_full/modeling_pi05.py::compute_layer_complete_knowledge_insulation``
    # on branch ``feat/add-pi05`` for the reference implementation.
    knowledge_insulation: bool = False
    """If True, route every dual-expert transformer layer through the
    KI attention path (blocks action→VLM gradient flow on K/V).
    ``PI052Policy.__init__`` binds the custom forward onto the
    ``PaliGemmaWithExpertModel`` instance, so stock ``pi05`` policies
    that share the class are unaffected."""

    def __post_init__(self) -> None:
        super().__post_init__()
        # Backbone needs gradients flowing through the text head when
@@ -39,10 +39,12 @@ from __future__ import annotations
import logging
import math
import types
from typing import Any

import torch
from torch import Tensor
from torch.nn import functional as F

from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
@@ -53,6 +55,190 @@ from .configuration_pi052 import PI052Config
logger = logging.getLogger(__name__)

# ----------------------------------------------------------------------
# Knowledge insulation — ported from pi05_full (branch ``feat/add-pi05``)
# ----------------------------------------------------------------------
#
# Per-layer attention that splits the queries into VLM and action
# parts, computing attention for action queries with .detach()'d VLM
# K/V so the action loss's gradient cannot flow back into the VLM's K
# and V projections. Forward output is bit-equivalent to the standard
# layer; backward differs only on the path action_loss → VLM K/V.
def _compute_layer_ki(
    layer_idx,
    inputs_embeds,
    attention_mask,
    position_ids,
    adarms_cond,
    paligemma,
    gemma_expert,
):
    from transformers.models.gemma import modeling_gemma  # noqa: PLC0415

    models = [paligemma.language_model, gemma_expert.model]
    query_states, key_states, value_states, gates = [], [], [], []
    vlm_len = inputs_embeds[0].shape[1]
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        q = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        k = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        v = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states.append(q)
        key_states.append(k)
        value_states.append(v)

    query_states = torch.cat(query_states, dim=2)
    key_states = torch.cat(key_states, dim=2)
    value_states = torch.cat(value_states, dim=2)

    dummy = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )

    batch_size = query_states.shape[0]
    scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling

    # Split queries / K / V at the VLM-vs-action boundary.
    Q_vlm = query_states[:, :, :vlm_len, :]
    Q_action = query_states[:, :, vlm_len:, :]
    K_vlm = key_states[:, :, :vlm_len, :]
    K_action = key_states[:, :, vlm_len:, :]
    V_vlm = value_states[:, :, :vlm_len, :]
    V_action = value_states[:, :, vlm_len:, :]

    # Detach VLM K/V *only* on the path the action queries use.
    K_vlm_det = K_vlm.detach()
    V_vlm_det = V_vlm.detach()
    K_for_vlm = key_states  # full (gradients flow)
    V_for_vlm = value_states
    K_for_action = torch.cat([K_vlm_det, K_action], dim=2)
    V_for_action = torch.cat([V_vlm_det, V_action], dim=2)

    mask_for_vlm = attention_mask[:, :, :vlm_len, :]
    mask_for_action = attention_mask[:, :, vlm_len:, :]

    att_vlm, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        Q_vlm, K_for_vlm, V_for_vlm, mask_for_vlm, scaling,
    )
    att_action, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        Q_action, K_for_action, V_for_action, mask_for_action, scaling,
    )
    att = torch.cat([att_vlm, att_action], dim=1)

    head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
    att = att.reshape(batch_size, -1, 1 * 8 * head_dim)

    outputs_embeds = []
    start = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        end = start + hidden_states.shape[1]
        if att.dtype != layer.self_attn.o_proj.weight.dtype:
            att = att.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att[:, start:end])
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        after_first = out_emb.clone()
        out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        out_emb = modeling_gemma._gated_residual(after_first, out_emb, gate)  # noqa: SLF001
        outputs_embeds.append(out_emb)
        start = end
    return outputs_embeds


def _paligemma_forward_ki(
    self,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    use_cache=None,
    adarms_cond=None,
    fill_kv_cache=None,
):
    """Replacement ``PaliGemmaWithExpertModel.forward`` that routes the
    dual-expert layer pass through :func:`_compute_layer_ki`.

    Bound onto the model instance when ``config.knowledge_insulation``
    is True (see ``PI052Policy.__init__``). Single-expert branches
    (VLM-only or action-only) reuse the original forward because
    there's no KI signal to add: KI only matters when actions and VLM
    tokens are forwarded together.
    """
    from ..pi05.modeling_pi05 import layernorm_forward  # noqa: PLC0415

    if adarms_cond is None:
        adarms_cond = [None, None]

    # Single-expert paths: defer to the original forward, which
    # ``PI052Policy.__init__`` saves as ``_pi052_orig_forward`` before
    # binding this replacement.
    if inputs_embeds[0] is None or inputs_embeds[1] is None:
        return self._pi052_orig_forward(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            adarms_cond=adarms_cond,
        )

    models = [self.paligemma.model.language_model, self.gemma_expert.model]
    num_layers = self.paligemma.config.text_config.num_hidden_layers
    use_gc = (
        hasattr(self.gemma_expert.model, "gradient_checkpointing")
        and self.gemma_expert.model.gradient_checkpointing
        and self.training
    ) or (
        hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training
    )

    for layer_idx in range(num_layers):
        if use_gc:
            inputs_embeds = torch.utils.checkpoint.checkpoint(
                _compute_layer_ki,
                layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
                use_reentrant=False, preserve_rng_state=False,
                paligemma=self.paligemma, gemma_expert=self.gemma_expert,
            )
        else:
            inputs_embeds = _compute_layer_ki(
                layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
                paligemma=self.paligemma, gemma_expert=self.gemma_expert,
            )

    outputs_embeds = []
    for i, hidden_states in enumerate(inputs_embeds):
        out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
        outputs_embeds.append(out_emb)
    return [outputs_embeds[0], outputs_embeds[1]], None


class PI052Policy(PI05Policy):
    """π0.5 with the PaliGemma LM head re-enabled."""
@@ -68,6 +254,20 @@ class PI052Policy(PI05Policy):
        if config.text_loss_weight > 0 and config.unfreeze_lm_head:
            self._unfreeze_lm_head()

        # Knowledge insulation: bind a custom ``forward`` on the
        # PaliGemmaWithExpertModel instance that uses
        # :func:`_compute_layer_ki` for the dual-expert layer pass.
        # The bind is per-instance, so this doesn't leak into stock
        # ``pi05`` policies that share the same class.
        if getattr(config, "knowledge_insulation", False):
            backbone = self.model.paligemma_with_expert
            backbone._pi052_orig_forward = backbone.forward
            backbone.forward = types.MethodType(_paligemma_forward_ki, backbone)
            logger.info(
                "PI052: knowledge insulation enabled — action→VLM K/V "
                "gradients are blocked in attention."
            )

    # ------------------------------------------------------------------
    # Head unfreeze helper
    # ------------------------------------------------------------------
@@ -143,6 +343,22 @@ class PI052Policy(PI05Policy):
                else total + self.config.text_loss_weight * text_loss
            )

        # FAST action-token CE loss (paper §III.C). When
        # ``enable_fast_action_loss=True`` the preprocessor wrote
        # ACTION_TOKENS / ACTION_TOKEN_MASK into the batch — we
        # forward them through the PaliGemma backbone alongside the
        # language prefix and compute CE on the action positions.
        if getattr(self.config, "enable_fast_action_loss", False):
            from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS  # noqa: PLC0415

            action_tokens = batch.get(ACTION_TOKENS)
            action_mask = batch.get(ACTION_TOKEN_MASK)
            if action_tokens is not None and action_mask is not None:
                fast_loss = self._compute_fast_action_loss(batch, action_tokens, action_mask)
                loss_dict["fast_action_loss"] = float(fast_loss.detach().item())
                weighted = self.config.fast_action_loss_weight * fast_loss
                total = weighted if total is None else total + weighted

        if total is None:
            # Both flow and text disabled — make this an obvious bug
            # rather than a silent zero loss.
@@ -161,6 +377,82 @@ class PI052Policy(PI05Policy):
    # Text loss
    # ------------------------------------------------------------------

    def _compute_fast_action_loss(
        self,
        batch: dict[str, Tensor],
        action_tokens: Tensor,
        action_mask: Tensor,
    ) -> Tensor:
        """Cross-entropy on FAST-tokenised actions via PaliGemma's LM head.

        Mirrors the paper's §III.C action-token loss: append the FAST
        tokens to the language prefix, forward through the backbone,
        slice the per-token logits at the action positions, and
        compute CE against the FAST targets (shifted for next-token
        prediction). Token-mask gates the loss to valid positions.
        """
        from ..pi05.modeling_pi05 import make_att_2d_masks  # noqa: PLC0415

        images, img_masks = self.model._preprocess_images(batch)
        lang_tokens = batch[OBS_LANGUAGE_TOKENS]
        lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]

        # Re-embed prefix to get [images, language] embeddings, then
        # append the FAST action token embeddings as additional
        # prefix tokens. PaliGemma's embed_language_tokens is shared
        # between text and FAST tokens so we can re-use it directly.
        prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        emb_dim = prefix_embs.shape[-1]
        fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
        fast_emb = fast_emb * math.sqrt(emb_dim)

        # Concat onto the prefix. Pad masks: language uses
        # ``lang_masks``; FAST uses ``action_mask`` (True at valid
        # token positions). Attention masks add ``True`` (causal)
        # for FAST so they can attend to the bidirectional prefix
        # but only causally among themselves.
        bsize, fast_len = action_tokens.shape
        device = prefix_embs.device
        ones_att = torch.ones((bsize, fast_len), dtype=torch.bool, device=device)
        full_embs = torch.cat([prefix_embs, fast_emb], dim=1)
        full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
        full_att = torch.cat([prefix_att, ones_att], dim=1)

        att_2d = make_att_2d_masks(full_pad, full_att)
        position_ids = torch.cumsum(full_pad, dim=1) - 1

        (vlm_out, _), _ = self.model.paligemma_with_expert.forward(
            attention_mask=att_2d,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=[full_embs, None],
            use_cache=False,
            fill_kv_cache=True,
        )
        if vlm_out is None:
            raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.")

        # Slice the last ``fast_len`` positions — those correspond to
        # the FAST tokens we just appended.
        fast_hidden = vlm_out[:, -fast_len:, :]
        lm_head = self.model.paligemma_with_expert.paligemma.lm_head
        fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))

        # Shift for next-token prediction.
        shift_logits = fast_logits[:, :-1, :].contiguous()
        shift_targets = action_tokens[:, 1:].contiguous()
        shift_mask = action_mask[:, 1:].contiguous().float()

        per_tok = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_targets.view(-1),
            reduction="none",
        ).view(shift_targets.shape)
        loss = (per_tok * shift_mask).sum() / shift_mask.sum().clamp_min(1.0)
        return loss

    def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
        """Cross-entropy on PaliGemma's LM head over the supervised span.
+18 -1
@@ -43,6 +43,7 @@ import torch
from lerobot.configs.recipe import TrainingRecipe
from lerobot.processor import (
    AbsoluteActionsProcessorStep,
    ActionTokenizerProcessorStep,
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
@@ -101,9 +102,25 @@ def make_pi052_pre_post_processors(
            memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
            subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
        ),
        DeviceProcessorStep(device=config.device),
    ]

    # FAST tokenizer for discrete-action CE supervision (paper §III.C).
    # Only inserted when explicitly enabled — keeps the post-training-
    # style recipe (flow + text) as the default. When on, the step
    # writes ACTION_TOKENS / ACTION_TOKEN_MASK into
    # ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
    if getattr(config, "enable_fast_action_loss", False):
        input_steps.append(
            ActionTokenizerProcessorStep(
                action_tokenizer_name=config.action_tokenizer_name,
                max_action_tokens=config.max_action_tokens,
                fast_skip_tokens=config.fast_skip_tokens,
                paligemma_tokenizer_name="google/paligemma-3b-pt-224",
            )
        )
    input_steps.append(DeviceProcessorStep(device=config.device))

    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features,