feat(pi052): amortized K_repeat flow + separate backbone/expert LRs

Two π0.5-paper training techniques for pi052:

- flow_num_repeats (default 5): the action expert runs K independent
  noise/timestep draws against a single shared VLM prefix forward (tiled
  as block-diagonal suffix blocks with the FAST tokens masked out),
  amortizing the dominant backbone cost. Per-block flow losses are
  averaged so the backbone gradient stays well-scaled; pairs with
  knowledge_insulation (which additionally detaches the prefix K/V).
  flow_num_repeats=1 recovers the original single-draw combined forward.
- backbone_lr_scale / action_expert_lr_scale: separate LR groups for the
  pretrained PaliGemma backbone vs the from-scratch action expert, on top
  of the existing lm_head_lr_scale. Defaults of 1.0 keep single-LR behaviour.

PiGemmaRMSNorm now accepts per-token adaRMS conditioning so each tiled
block carries its own timestep (2D per-sample cond is unchanged).

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-06-24 20:30:03 +00:00
parent c31f1b0f72
commit ecb945eb4c
3 changed files with 251 additions and 59 deletions
@@ -168,7 +168,7 @@ class PI052Config(PI05Config):
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
# that splits queries into VLM and action halves and ``.detach()``-s
# the VLM K/V tensors used in the action-half's attention.
knowledge_insulation: bool = False
knowledge_insulation: bool = True
"""If True, route every transformer layer through the KI
attention path that blocks action→VLM gradient flow on K/V."""
@@ -189,6 +189,30 @@ class PI052Config(PI05Config):
# the same cosine lambda, so the 5x ratio is preserved across decay.
lm_head_lr_scale: float = 5.0
# Separate LRs for the VLM backbone vs the action expert (paper §III.B).
# The backbone is a pretrained PaliGemma; the action expert is trained
# from scratch, so their initialisation scales differ and a single global
# LR under-trains one of them. These multipliers scale the base
# ``optimizer_lr`` for each group; the cosine scheduler applies the same
# lambda to every group so the ratios hold across decay. ``backbone_lr_scale``
# covers the PaliGemma tower (except the LM head / tied embeddings, which keep
# their own ``lm_head_lr_scale``); ``action_expert_lr_scale`` covers the Gemma
# expert plus the action/time projection heads. Defaults of 1.0 reproduce the
# single-LR behaviour (back-compat with existing checkpoints/configs).
backbone_lr_scale: float = 1.0
action_expert_lr_scale: float = 1.0
# Amortized flow training (paper §III.B, K_repeat). The VLM/backbone forward
# dominates step cost; to extract more learning signal per VLM pass the action
# expert runs ``flow_num_repeats`` denoising targets per sample, each with an
# independent noise + timestep draw, all attending to the single shared VLM
# prefix. The per-repeat flow losses are averaged, so the backbone gradient
# stays well-scaled. Pairs naturally with ``knowledge_insulation`` (which
# additionally detaches the prefix K/V on the action path), the paper's
# setting — but the amortized path is also correct without it. Set to 1 to
# recover the original single-draw combined forward.
flow_num_repeats: int = 5
# PaLM-style z-loss on text CE. Penalises the log-partition function
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
@@ -250,3 +274,5 @@ class PI052Config(PI05Config):
# out of text training via ``text_loss_weight=0``.
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
self.train_expert_only = False
if self.flow_num_repeats < 1:
raise ValueError(f"flow_num_repeats must be >= 1, got {self.flow_num_repeats}")
+220 -57
View File
@@ -1257,39 +1257,21 @@ class PI052Policy(PreTrainedPolicy):
action_code_mask: Tensor | None,
predict_actions_t: Tensor | None = None,
) -> tuple[Tensor, Tensor | None, Tensor | None]:
"""Full fusion: flow + text + FAST in ONE backbone forward.
"""Flow + text + FAST losses, sharing a single VLM prefix forward.
Builds:
prefix = [images, language, FAST (when provided)]
suffix = [noisy_actions] (action expert via gemma_expert)
Then overrides the unified 2D attention mask to *explicitly*
zero out ``suffix → FAST`` attention. Without this override
the action expert would attend to the discrete FAST tokens
and trivially decode them back to the same continuous
actions it's supposed to predict via flow matching — the
whole training signal collapses.
Both prefix_out and suffix_out are captured from the same
forward. From prefix_out we slice the language and FAST
token positions and compute their CE losses. From suffix_out
we run the existing flow path (action_out_proj → MSE).
Embeds ``prefix = [images, language, FAST (when provided)]`` once, then
computes the flow loss via either a single combined forward
(``flow_num_repeats == 1``) or the amortized K-repeat path
(``> 1``); both keep the discrete FAST tokens invisible to the action
expert. The text/FAST CE losses are sliced from the shared
``prefix_out``.
Returns ``(flow_loss, text_loss, fast_loss)`` where text/fast
can be ``None`` when the caller didn't supply the
corresponding inputs.
"""
from lerobot.utils.constants import ACTION # noqa: PLC0415
# ---- preamble (mirrors PI05Pytorch.forward) ------------------
actions = self.prepare_action(batch)
noise = self.model.sample_noise(actions.shape, actions.device)
time = self.model.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
# ---- prefix: images + language + (optional FAST) -------------
images, img_masks = self._preprocess_images(batch)
@@ -1325,6 +1307,52 @@ class PI052Policy(PreTrainedPolicy):
prefix_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
prefix_att = torch.cat([prefix_att, ones_att], dim=1)
# ---- flow: one combined forward, or amortized over K repeats ----
# ``flow_num_repeats == 1`` keeps the single combined [prefix; suffix]
# forward. ``> 1`` runs the VLM prefix once and replays the action
# expert K times against fresh noise/timestep draws, reusing the
# cached prefix KV (paper §III.B). Both return ``prefix_out`` for the
# shared text/FAST CE tail.
num_repeats = int(getattr(self.config, "flow_num_repeats", 1))
if num_repeats > 1:
prefix_out, flow_loss = self._amortized_prefix_and_flow(
actions, prefix_embs, prefix_pad, prefix_att,
non_fast_prefix_len, fast_len, predict_actions_t, num_repeats,
)
else:
prefix_out, flow_loss = self._combined_prefix_and_flow(
actions, prefix_embs, prefix_pad, prefix_att,
non_fast_prefix_len, fast_len, predict_actions_t,
)
text_loss, fast_loss = self._prefix_ce_losses(
prefix_out, text_labels, action_tokens, action_code_mask, fast_len, predict_actions_t
)
return flow_loss, text_loss, fast_loss
def _combined_prefix_and_flow(
self,
actions: Tensor,
prefix_embs: Tensor,
prefix_pad: Tensor,
prefix_att: Tensor,
non_fast_prefix_len: int,
fast_len: int,
predict_actions_t: Tensor | None,
) -> tuple[Tensor, Tensor]:
"""Single combined [prefix; suffix] forward → (prefix_out, flow_loss).
This is the original (``flow_num_repeats == 1``) path: one noise/time
draw, one backbone forward producing both the VLM prefix hidden states
(for text/FAST CE) and the action-expert suffix hidden states (flow)."""
from lerobot.utils.constants import ACTION # noqa: PLC0415
noise = self.model.sample_noise(actions.shape, actions.device)
time = self.model.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
# ---- suffix: noisy actions ----------------------------------
suffix_embs, suffix_pad, suffix_att, adarms_cond = self.model.embed_suffix(x_t, time)
@@ -1380,8 +1408,132 @@ class PI052Policy(PreTrainedPolicy):
flow_per_dim = flow_per_dim[:, :, :original_action_dim]
per_sample_flow = flow_per_dim.mean(dim=(1, 2))
flow_loss = _mask_per_sample(per_sample_flow, predict_actions_t)
return prefix_out, flow_loss
# ---- text + FAST CE from prefix_out ------------------------
def _amortized_prefix_and_flow(
self,
actions: Tensor,
prefix_embs: Tensor,
prefix_pad: Tensor,
prefix_att: Tensor,
non_fast_prefix_len: int,
fast_len: int,
predict_actions_t: Tensor | None,
num_repeats: int,
) -> tuple[Tensor, Tensor]:
"""Amortized flow: one VLM prefix forward, K action-expert replays.
The VLM/backbone forward dominates step cost, so we keep a *single*
combined forward but tile the action suffix into ``num_repeats`` blocks,
each with an independent noise/timestep draw against the same action
chunk (paper §III.B, K_repeat). The blocks attend to the shared prefix
(FAST columns masked, exactly like the combined path) and are
block-diagonal among themselves, so the expensive prefix K/V is computed
once while the cheap action expert runs K times. Knowledge insulation
(``_compute_layer_ki``) detaches the prefix K/V for the action queries,
so this is gradient-equivalent to K independent draws sharing one VLM
forward. Per-block flow losses are averaged.
"""
from lerobot.utils.constants import ACTION # noqa: PLC0415
model = self.model
k = num_repeats
chunk = self.config.chunk_size
batch_size, prefix_len = prefix_pad.shape
first_layer = model.paligemma_with_expert.paligemma.model.language_model.layers[0]
use_bf16 = first_layer.self_attn.q_proj.weight.dtype == torch.bfloat16
if use_bf16:
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
# ---- K suffix blocks: independent noise/time draws ----------
suffix_blocks: list[Tensor] = []
adarms_blocks: list[Tensor] = []
u_t_blocks: list[Tensor] = []
suffix_pad = suffix_att = None
for _ in range(k):
noise = model.sample_noise(actions.shape, actions.device)
time = model.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t_blocks.append(noise - actions)
s_embs, suffix_pad, suffix_att, adarms = model.embed_suffix(x_t, time)
if use_bf16:
s_embs = s_embs.to(dtype=torch.bfloat16)
suffix_blocks.append(s_embs)
# adaRMS time conditioning is per-sample; broadcast it across this
# block's chunk tokens so each block carries its own timestep.
adarms_blocks.append(adarms[:, None, :].expand(batch_size, chunk, adarms.shape[-1]))
suffix_embs = torch.cat(suffix_blocks, dim=1) # (B, k*chunk, D)
adarms_cond = torch.cat(adarms_blocks, dim=1) # (B, k*chunk, cond_dim)
# ---- block-diagonal attention over [prefix | block_1..k] ----
# Prefix rows keep their own (causal/text) attention and never see the
# action blocks. Each action block attends to the valid prefix (minus
# FAST) and only to itself.
prefix_att_2d = make_att_2d_masks(prefix_pad, prefix_att) # (B, P, P)
device = prefix_pad.device
prefix_rows = torch.cat(
[prefix_att_2d, torch.zeros(batch_size, prefix_len, k * chunk, dtype=torch.bool, device=device)],
dim=2,
)
action_to_prefix = prefix_pad[:, None, :].expand(batch_size, k * chunk, prefix_len).clone()
if fast_len > 0:
action_to_prefix[:, :, non_fast_prefix_len:prefix_len] = False
block_diag = torch.block_diag(
*[torch.ones(chunk, chunk, dtype=torch.bool, device=device) for _ in range(k)]
)
action_to_action = block_diag[None].expand(batch_size, k * chunk, k * chunk)
action_rows = torch.cat([action_to_prefix, action_to_action], dim=2)
att_2d = torch.cat([prefix_rows, action_rows], dim=1) # (B, P + k*chunk, P + k*chunk)
att_2d_4d = model._prepare_attention_masks_4d(att_2d, dtype=prefix_embs.dtype)
# Positions: prefix as usual; every block restarts at the prefix offset
# (each block is an independent denoising of the same chunk).
prefix_offsets = torch.sum(prefix_pad, dim=-1)[:, None]
block_positions = prefix_offsets + torch.cumsum(suffix_pad, dim=1) - 1 # (B, chunk)
position_ids = torch.cat([torch.cumsum(prefix_pad, dim=1) - 1, block_positions.repeat(1, k)], dim=1)
(prefix_out, suffix_out), _ = model.paligemma_with_expert.forward(
attention_mask=att_2d_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
# ---- flow loss averaged over the K blocks -------------------
original_action_dim = self.config.output_features[ACTION].shape[0]
flow_accum: Tensor | None = None
for i in range(k):
block_out = suffix_out[:, i * chunk : (i + 1) * chunk].to(dtype=torch.float32)
v_t = model.action_out_proj(block_out)
flow_per_dim = F.mse_loss(u_t_blocks[i], v_t, reduction="none")[:, :, :original_action_dim]
per_sample_flow = flow_per_dim.mean(dim=(1, 2))
flow_accum = per_sample_flow if flow_accum is None else flow_accum + per_sample_flow
per_sample_flow = flow_accum / k
flow_loss = _mask_per_sample(per_sample_flow, predict_actions_t)
return prefix_out, flow_loss
def _prefix_ce_losses(
self,
prefix_out: Tensor | None,
text_labels: Tensor | None,
action_tokens: Tensor | None,
action_code_mask: Tensor | None,
fast_len: int,
predict_actions_t: Tensor | None,
) -> tuple[Tensor | None, Tensor | None]:
"""Text-CE + FAST-CE from the VLM prefix hidden states.
Shared by the combined and amortized flow paths: slices the language
and FAST token positions out of ``prefix_out`` and runs the fused
linear-CE heads. Either loss is ``None`` when its inputs are absent."""
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
text_loss: Tensor | None = None
@@ -1412,7 +1564,7 @@ class PI052Policy(PreTrainedPolicy):
predict_actions_t,
)
return flow_loss, text_loss, fast_loss
return text_loss, fast_loss
def _compute_text_and_fast_loss(
self,
@@ -2241,57 +2393,68 @@ class PI052Policy(PreTrainedPolicy):
def get_optim_params(self):
"""Return policy parameters, optionally split into LR-scaled groups.
When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head``
and its tied ``embed_tokens`` are placed in their own param
group with ``lr = base_lr * lm_head_lr_scale``. The cosine
scheduler multiplies both groups by the same lambda each step,
so the ratio is preserved across decay. Default ``1.0`` =
return ``self.parameters()`` (back-compat with existing checkpoints
and configs).
Three orthogonal multipliers scale the base ``optimizer_lr``:
``lm_head_lr_scale`` (PaliGemma ``lm_head`` + tied ``embed_tokens``),
``backbone_lr_scale`` (the rest of the PaliGemma tower), and
``action_expert_lr_scale`` (the Gemma expert + action/time projection
heads). The cosine scheduler multiplies every group by the same lambda
each step so the ratios are preserved across decay. When all three are
``1.0`` this returns ``self.parameters()`` (back-compat with existing
checkpoints and configs).
"""
scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
if scale == 1.0:
head_scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
backbone_scale = float(getattr(self.config, "backbone_lr_scale", 1.0))
expert_scale = float(getattr(self.config, "action_expert_lr_scale", 1.0))
if head_scale == 1.0 and backbone_scale == 1.0 and expert_scale == 1.0:
return self.parameters()
head_params: list[torch.nn.Parameter] = []
other_params: list[torch.nn.Parameter] = []
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` —
# boosting only the projection without the embedding pulls them
# apart and breaks the tie that PaliGemma was pre-trained with.
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` go in the
# head group — boosting only the projection without the embedding pulls
# them apart and breaks the tie PaliGemma was pre-trained with.
head_substrings = (
"paligemma_with_expert.paligemma.lm_head.",
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.",
)
backbone_substring = "paligemma_with_expert.paligemma."
head_params: list[torch.nn.Parameter] = []
backbone_params: list[torch.nn.Parameter] = []
expert_params: list[torch.nn.Parameter] = []
for name, p in self.named_parameters():
if not p.requires_grad:
continue
if any(s in name for s in head_substrings):
head_params.append(p)
elif backbone_substring in name:
backbone_params.append(p)
else:
other_params.append(p)
expert_params.append(p)
base_lr = float(self.config.optimizer_lr)
groups: list[dict[str, object]] = []
if other_params:
groups.append({"params": other_params, "lr": base_lr, "name": "policy"})
if head_params:
if backbone_params:
groups.append(
{"params": head_params, "lr": base_lr * scale, "name": "lm_head"}
{"params": backbone_params, "lr": base_lr * backbone_scale, "name": "backbone"}
)
# Sanity: head_substrings must match at least one parameter, otherwise
# the scale silently does nothing — surface that fast.
if not head_params:
if expert_params:
groups.append(
{"params": expert_params, "lr": base_lr * expert_scale, "name": "action_expert"}
)
if head_params:
groups.append({"params": head_params, "lr": base_lr * head_scale, "name": "lm_head"})
# Sanity: a non-trivial head scale that matches no params would silently
# do nothing — surface that fast.
if head_scale != 1.0 and not head_params:
raise RuntimeError(
"lm_head_lr_scale != 1.0 but no parameters matched the LM-head "
"name patterns: "
f"{head_substrings!r}. Did the underlying PaliGemma module rename?"
f"name patterns: {head_substrings!r}. Did the underlying PaliGemma "
"module rename?"
)
logging.info(
"PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over "
"%d head params + %d other params",
scale,
"PI052Policy LR groups (base=%.3g): backbone=%.3g (×%.3g, n=%d), "
"action_expert=%.3g (×%.3g, n=%d), lm_head=%.3g (×%.3g, n=%d)",
base_lr,
base_lr * scale,
len(head_params),
len(other_params),
base_lr * backbone_scale, backbone_scale, len(backbone_params),
base_lr * expert_scale, expert_scale, len(expert_params),
base_lr * head_scale, head_scale, len(head_params),
)
return groups
+4 -1
View File
@@ -133,7 +133,10 @@ class PiGemmaRMSNorm(nn.Module):
if cond.shape[-1] != self.cond_dim:
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
modulation = self.dense(cond)
if len(x.shape) == 3:
# Per-sample cond (B, cond_dim) → broadcast over the sequence. A
# per-token cond (B, T, cond_dim) is already aligned with x and must
# not be unsqueezed (used by pi052's amortized K_repeat path).
if len(x.shape) == 3 and modulation.dim() == 2:
modulation = modulation.unsqueeze(1)
scale, shift, gate = modulation.chunk(3, dim=-1)
normed = normed * (1 + scale.float()) + shift.float()