mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 12:47:18 +00:00
feat(pi052): amortized K_repeat flow + separate backbone/expert LRs
Two π0.5-paper training techniques for pi052:

- flow_num_repeats (default 5): the action expert runs K independent noise/timestep draws against a single shared VLM prefix forward (tiled as block-diagonal suffix blocks with the FAST tokens masked out), amortizing the dominant backbone cost. Per-block flow losses are averaged so the backbone gradient stays well-scaled; pairs with knowledge_insulation (which additionally detaches the prefix K/V). flow_num_repeats=1 recovers the original single-draw combined forward.
- backbone_lr_scale / action_expert_lr_scale: separate LR groups for the pretrained PaliGemma backbone vs the from-scratch action expert, on top of the existing lm_head_lr_scale. Defaults of 1.0 keep single-LR behaviour.

PiGemmaRMSNorm now accepts per-token adaRMS conditioning so each tiled block carries its own timestep (2D per-sample cond is unchanged).

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -168,7 +168,7 @@ class PI052Config(PI05Config):
|
||||
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
|
||||
# that splits queries into VLM and action halves and ``.detach()``-s
|
||||
# the VLM K/V tensors used in the action-half's attention.
|
||||
knowledge_insulation: bool = False
|
||||
knowledge_insulation: bool = True
|
||||
"""If True, route every transformer layer through the KI
|
||||
attention path that blocks action→VLM gradient flow on K/V."""
|
||||
|
||||
@@ -189,6 +189,30 @@ class PI052Config(PI05Config):
|
||||
# the same cosine lambda, so the 5x ratio is preserved across decay.
|
||||
lm_head_lr_scale: float = 5.0
|
||||
|
||||
# Separate LRs for the VLM backbone vs the action expert (paper §III.B).
|
||||
# The backbone is a pretrained PaliGemma; the action expert is trained
|
||||
# from scratch, so their initialisation scales differ and a single global
|
||||
# LR under-trains one of them. These multipliers scale the base
|
||||
# ``optimizer_lr`` for each group; the cosine scheduler applies the same
|
||||
# lambda to every group so the ratios hold across decay. ``backbone_lr_scale``
|
||||
# covers the PaliGemma tower (except the LM head / tied embeddings, which keep
|
||||
# their own ``lm_head_lr_scale``); ``action_expert_lr_scale`` covers the Gemma
|
||||
# expert plus the action/time projection heads. Defaults of 1.0 reproduce the
|
||||
# single-LR behaviour (back-compat with existing checkpoints/configs).
|
||||
backbone_lr_scale: float = 1.0
|
||||
action_expert_lr_scale: float = 1.0
|
||||
|
||||
# Amortized flow training (paper §III.B, K_repeat). The VLM/backbone forward
|
||||
# dominates step cost; to extract more learning signal per VLM pass the action
|
||||
# expert runs ``flow_num_repeats`` denoising targets per sample, each with an
|
||||
# independent noise + timestep draw, all attending to the single shared VLM
|
||||
# prefix. The per-repeat flow losses are averaged, so the backbone gradient
|
||||
# stays well-scaled. Pairs naturally with ``knowledge_insulation`` (which
|
||||
# additionally detaches the prefix K/V on the action path), the paper's
|
||||
# setting — but the amortized path is also correct without it. Set to 1 to
|
||||
# recover the original single-draw combined forward.
|
||||
flow_num_repeats: int = 5
|
||||
|
||||
# PaLM-style z-loss on text CE. Penalises the log-partition function
|
||||
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
|
||||
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
|
||||
@@ -250,3 +274,5 @@ class PI052Config(PI05Config):
|
||||
# out of text training via ``text_loss_weight=0``.
|
||||
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
|
||||
self.train_expert_only = False
|
||||
if self.flow_num_repeats < 1:
|
||||
raise ValueError(f"flow_num_repeats must be >= 1, got {self.flow_num_repeats}")
|
||||
|
||||
@@ -1257,39 +1257,21 @@ class PI052Policy(PreTrainedPolicy):
|
||||
action_code_mask: Tensor | None,
|
||||
predict_actions_t: Tensor | None = None,
|
||||
) -> tuple[Tensor, Tensor | None, Tensor | None]:
|
||||
"""Full fusion: flow + text + FAST in ONE backbone forward.
|
||||
"""Flow + text + FAST losses, sharing a single VLM prefix forward.
|
||||
|
||||
Builds:
|
||||
|
||||
prefix = [images, language, FAST (when provided)]
|
||||
suffix = [noisy_actions] (action expert via gemma_expert)
|
||||
|
||||
Then overrides the unified 2D attention mask to *explicitly*
|
||||
zero out ``suffix → FAST`` attention. Without this override
|
||||
the action expert would attend to the discrete FAST tokens
|
||||
and trivially decode them back to the same continuous
|
||||
actions it's supposed to predict via flow matching — the
|
||||
whole training signal collapses.
|
||||
|
||||
Both prefix_out and suffix_out are captured from the same
|
||||
forward. From prefix_out we slice the language and FAST
|
||||
token positions and compute their CE losses. From suffix_out
|
||||
we run the existing flow path (action_out_proj → MSE).
|
||||
Embeds ``prefix = [images, language, FAST (when provided)]`` once, then
|
||||
computes the flow loss via either a single combined forward
|
||||
(``flow_num_repeats == 1``) or the amortized K-repeat path
|
||||
(``> 1``); both keep the discrete FAST tokens invisible to the action
|
||||
expert. The text/FAST CE losses are sliced from the shared
|
||||
``prefix_out``.
|
||||
|
||||
Returns ``(flow_loss, text_loss, fast_loss)`` where text/fast
|
||||
can be ``None`` when the caller didn't supply the
|
||||
corresponding inputs.
|
||||
"""
|
||||
from lerobot.utils.constants import ACTION # noqa: PLC0415
|
||||
|
||||
|
||||
# ---- preamble (mirrors PI05Pytorch.forward) ------------------
|
||||
actions = self.prepare_action(batch)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
# ---- prefix: images + language + (optional FAST) -------------
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
@@ -1325,6 +1307,52 @@ class PI052Policy(PreTrainedPolicy):
|
||||
prefix_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
|
||||
prefix_att = torch.cat([prefix_att, ones_att], dim=1)
|
||||
|
||||
# ---- flow: one combined forward, or amortized over K repeats ----
|
||||
# ``flow_num_repeats == 1`` keeps the single combined [prefix; suffix]
|
||||
# forward. ``> 1`` runs the VLM prefix once and replays the action
|
||||
# expert K times against fresh noise/timestep draws, reusing the
|
||||
# cached prefix KV (paper §III.B). Both return ``prefix_out`` for the
|
||||
# shared text/FAST CE tail.
|
||||
num_repeats = int(getattr(self.config, "flow_num_repeats", 1))
|
||||
if num_repeats > 1:
|
||||
prefix_out, flow_loss = self._amortized_prefix_and_flow(
|
||||
actions, prefix_embs, prefix_pad, prefix_att,
|
||||
non_fast_prefix_len, fast_len, predict_actions_t, num_repeats,
|
||||
)
|
||||
else:
|
||||
prefix_out, flow_loss = self._combined_prefix_and_flow(
|
||||
actions, prefix_embs, prefix_pad, prefix_att,
|
||||
non_fast_prefix_len, fast_len, predict_actions_t,
|
||||
)
|
||||
|
||||
text_loss, fast_loss = self._prefix_ce_losses(
|
||||
prefix_out, text_labels, action_tokens, action_code_mask, fast_len, predict_actions_t
|
||||
)
|
||||
return flow_loss, text_loss, fast_loss
|
||||
|
||||
def _combined_prefix_and_flow(
|
||||
self,
|
||||
actions: Tensor,
|
||||
prefix_embs: Tensor,
|
||||
prefix_pad: Tensor,
|
||||
prefix_att: Tensor,
|
||||
non_fast_prefix_len: int,
|
||||
fast_len: int,
|
||||
predict_actions_t: Tensor | None,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Single combined [prefix; suffix] forward → (prefix_out, flow_loss).
|
||||
|
||||
This is the original (``flow_num_repeats == 1``) path: one noise/time
|
||||
draw, one backbone forward producing both the VLM prefix hidden states
|
||||
(for text/FAST CE) and the action-expert suffix hidden states (flow)."""
|
||||
from lerobot.utils.constants import ACTION # noqa: PLC0415
|
||||
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
# ---- suffix: noisy actions ----------------------------------
|
||||
suffix_embs, suffix_pad, suffix_att, adarms_cond = self.model.embed_suffix(x_t, time)
|
||||
|
||||
@@ -1380,8 +1408,132 @@ class PI052Policy(PreTrainedPolicy):
|
||||
flow_per_dim = flow_per_dim[:, :, :original_action_dim]
|
||||
per_sample_flow = flow_per_dim.mean(dim=(1, 2))
|
||||
flow_loss = _mask_per_sample(per_sample_flow, predict_actions_t)
|
||||
return prefix_out, flow_loss
|
||||
|
||||
# ---- text + FAST CE from prefix_out ------------------------
|
||||
def _amortized_prefix_and_flow(
    self,
    actions: Tensor,
    prefix_embs: Tensor,
    prefix_pad: Tensor,
    prefix_att: Tensor,
    non_fast_prefix_len: int,
    fast_len: int,
    predict_actions_t: Tensor | None,
    num_repeats: int,
) -> tuple[Tensor, Tensor]:
    """Amortized flow: one VLM prefix forward, ``num_repeats`` action blocks.

    A *single* combined forward is run, but the action suffix is tiled into
    ``num_repeats`` blocks. Every block gets its own noise/timestep draw
    against the same ``actions`` chunk. All blocks attend to the shared
    prefix (with the FAST columns masked out) and only to themselves
    (block-diagonal), so the prefix is processed once while the action
    expert sees K independent denoising targets. The per-block flow losses
    are averaged.

    Args:
        actions: Target action chunk; the noise draw uses ``actions.shape``.
        prefix_embs: Prefix embeddings passed to the backbone.
        prefix_pad: Prefix padding mask of shape ``(batch, prefix_len)``.
        prefix_att: Prefix attention pattern forwarded to
            ``make_att_2d_masks`` together with ``prefix_pad``.
        non_fast_prefix_len: Index of the first FAST column in the prefix.
        fast_len: Number of FAST tokens; when ``> 0`` the columns
            ``[non_fast_prefix_len, prefix_len)`` are hidden from every
            action block.
        predict_actions_t: Optional per-sample selector handed to
            ``_mask_per_sample`` along with the per-sample flow loss.
        num_repeats: Number of independent noise/timestep draws (K).

    Returns:
        ``(prefix_out, flow_loss)`` — the prefix hidden states from the
        combined forward, and the result of ``_mask_per_sample`` applied
        to the K-averaged per-sample flow loss.

    Raises:
        ValueError: If ``num_repeats`` is smaller than 1.
    """
    from lerobot.utils.constants import ACTION  # noqa: PLC0415

    k = num_repeats
    if k < 1:
        # Fail with a clear message instead of an opaque ``torch.cat`` error
        # on an empty list further down.
        raise ValueError(f"num_repeats must be >= 1, got {k}")

    model = self.model
    batch_size, prefix_len = prefix_pad.shape

    # Match the activations to the backbone weights' dtype.
    first_layer = model.paligemma_with_expert.paligemma.model.language_model.layers[0]
    use_bf16 = first_layer.self_attn.q_proj.weight.dtype == torch.bfloat16
    if use_bf16:
        prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

    # ---- K suffix blocks: independent noise/time draws ----------
    suffix_blocks: list[Tensor] = []
    adarms_blocks: list[Tensor] = []
    u_t_blocks: list[Tensor] = []
    suffix_pad = None
    for _ in range(k):
        noise = model.sample_noise(actions.shape, actions.device)
        time = model.sample_time(actions.shape[0], actions.device)
        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t_blocks.append(noise - actions)  # flow-matching regression target
        s_embs, suffix_pad, _, adarms = model.embed_suffix(x_t, time)
        if use_bf16:
            s_embs = s_embs.to(dtype=torch.bfloat16)
        suffix_blocks.append(s_embs)
        # adaRMS time conditioning is per-sample; broadcast it across this
        # block's tokens so each block carries its own timestep. The token
        # count is taken from the embedded suffix itself so the conditioning
        # always lines up with ``suffix_embs``.
        adarms_blocks.append(
            adarms[:, None, :].expand(batch_size, s_embs.shape[1], adarms.shape[-1])
        )

    # Length of one block, derived from what ``embed_suffix`` actually
    # produced rather than from ``config.chunk_size``; the two are expected
    # to be equal, but masks and slices must follow the real tensor.
    chunk = suffix_blocks[0].shape[1]

    suffix_embs = torch.cat(suffix_blocks, dim=1)  # (B, k*chunk, D)
    adarms_cond = torch.cat(adarms_blocks, dim=1)  # (B, k*chunk, cond_dim)

    # ---- block-diagonal attention over [prefix | block_1..k] ----
    # Prefix rows keep their own (causal/text) attention and never see the
    # action blocks. Each action block attends to the valid prefix (minus
    # FAST) and only to itself.
    prefix_att_2d = make_att_2d_masks(prefix_pad, prefix_att)  # (B, P, P)
    device = prefix_pad.device
    prefix_rows = torch.cat(
        [
            prefix_att_2d,
            torch.zeros(batch_size, prefix_len, k * chunk, dtype=torch.bool, device=device),
        ],
        dim=2,
    )

    # ``expand`` returns a view; clone before the in-place FAST masking.
    action_to_prefix = prefix_pad[:, None, :].expand(batch_size, k * chunk, prefix_len).clone()
    if fast_len > 0:
        action_to_prefix[:, :, non_fast_prefix_len:prefix_len] = False
    block_diag = torch.block_diag(
        *[torch.ones(chunk, chunk, dtype=torch.bool, device=device) for _ in range(k)]
    )
    action_to_action = block_diag[None].expand(batch_size, k * chunk, k * chunk)
    action_rows = torch.cat([action_to_prefix, action_to_action], dim=2)

    att_2d = torch.cat([prefix_rows, action_rows], dim=1)  # (B, P + k*chunk, P + k*chunk)
    att_2d_4d = model._prepare_attention_masks_4d(att_2d, dtype=prefix_embs.dtype)

    # Positions: prefix as usual; every block restarts at the prefix offset
    # (each block is an independent denoising of the same chunk).
    prefix_offsets = torch.sum(prefix_pad, dim=-1)[:, None]
    block_positions = prefix_offsets + torch.cumsum(suffix_pad, dim=1) - 1  # (B, chunk)
    position_ids = torch.cat(
        [torch.cumsum(prefix_pad, dim=1) - 1, block_positions.repeat(1, k)], dim=1
    )

    (prefix_out, suffix_out), _ = model.paligemma_with_expert.forward(
        attention_mask=att_2d_4d,
        position_ids=position_ids,
        past_key_values=None,
        inputs_embeds=[prefix_embs, suffix_embs],
        use_cache=False,
        adarms_cond=[None, adarms_cond],
    )

    # ---- flow loss averaged over the K blocks -------------------
    original_action_dim = self.config.output_features[ACTION].shape[0]
    per_block_flow: list[Tensor] = []
    for i, u_t in enumerate(u_t_blocks):
        block_out = suffix_out[:, i * chunk : (i + 1) * chunk].to(dtype=torch.float32)
        v_t = model.action_out_proj(block_out)
        # Keep only the first ``original_action_dim`` action dimensions.
        flow_per_dim = F.mse_loss(u_t, v_t, reduction="none")[:, :, :original_action_dim]
        per_block_flow.append(flow_per_dim.mean(dim=(1, 2)))

    per_sample_flow = sum(per_block_flow) / k
    flow_loss = _mask_per_sample(per_sample_flow, predict_actions_t)
    return prefix_out, flow_loss
|
||||
|
||||
def _prefix_ce_losses(
|
||||
self,
|
||||
prefix_out: Tensor | None,
|
||||
text_labels: Tensor | None,
|
||||
action_tokens: Tensor | None,
|
||||
action_code_mask: Tensor | None,
|
||||
fast_len: int,
|
||||
predict_actions_t: Tensor | None,
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Text-CE + FAST-CE from the VLM prefix hidden states.
|
||||
|
||||
Shared by the combined and amortized flow paths: slices the language
|
||||
and FAST token positions out of ``prefix_out`` and runs the fused
|
||||
linear-CE heads. Either loss is ``None`` when its inputs are absent."""
|
||||
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
text_loss: Tensor | None = None
|
||||
@@ -1412,7 +1564,7 @@ class PI052Policy(PreTrainedPolicy):
|
||||
predict_actions_t,
|
||||
)
|
||||
|
||||
return flow_loss, text_loss, fast_loss
|
||||
return text_loss, fast_loss
|
||||
|
||||
def _compute_text_and_fast_loss(
|
||||
self,
|
||||
@@ -2241,57 +2393,68 @@ class PI052Policy(PreTrainedPolicy):
|
||||
def get_optim_params(self):
|
||||
"""Return policy parameters, optionally split into LR-scaled groups.
|
||||
|
||||
When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head``
|
||||
and its tied ``embed_tokens`` are placed in their own param
|
||||
group with ``lr = base_lr * lm_head_lr_scale``. The cosine
|
||||
scheduler multiplies both groups by the same lambda each step,
|
||||
so the ratio is preserved across decay. Default ``1.0`` =
|
||||
return ``self.parameters()`` (back-compat with existing checkpoints
|
||||
and configs).
|
||||
Three orthogonal multipliers scale the base ``optimizer_lr``:
|
||||
``lm_head_lr_scale`` (PaliGemma ``lm_head`` + tied ``embed_tokens``),
|
||||
``backbone_lr_scale`` (the rest of the PaliGemma tower), and
|
||||
``action_expert_lr_scale`` (the Gemma expert + action/time projection
|
||||
heads). The cosine scheduler multiplies every group by the same lambda
|
||||
each step so the ratios are preserved across decay. When all three are
|
||||
``1.0`` this returns ``self.parameters()`` (back-compat with existing
|
||||
checkpoints and configs).
|
||||
"""
|
||||
scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
|
||||
if scale == 1.0:
|
||||
head_scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
|
||||
backbone_scale = float(getattr(self.config, "backbone_lr_scale", 1.0))
|
||||
expert_scale = float(getattr(self.config, "action_expert_lr_scale", 1.0))
|
||||
if head_scale == 1.0 and backbone_scale == 1.0 and expert_scale == 1.0:
|
||||
return self.parameters()
|
||||
head_params: list[torch.nn.Parameter] = []
|
||||
other_params: list[torch.nn.Parameter] = []
|
||||
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` —
|
||||
# boosting only the projection without the embedding pulls them
|
||||
# apart and breaks the tie that PaliGemma was pre-trained with.
|
||||
|
||||
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` go in the
|
||||
# head group — boosting only the projection without the embedding pulls
|
||||
# them apart and breaks the tie PaliGemma was pre-trained with.
|
||||
head_substrings = (
|
||||
"paligemma_with_expert.paligemma.lm_head.",
|
||||
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.",
|
||||
)
|
||||
backbone_substring = "paligemma_with_expert.paligemma."
|
||||
head_params: list[torch.nn.Parameter] = []
|
||||
backbone_params: list[torch.nn.Parameter] = []
|
||||
expert_params: list[torch.nn.Parameter] = []
|
||||
for name, p in self.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
if any(s in name for s in head_substrings):
|
||||
head_params.append(p)
|
||||
elif backbone_substring in name:
|
||||
backbone_params.append(p)
|
||||
else:
|
||||
other_params.append(p)
|
||||
expert_params.append(p)
|
||||
base_lr = float(self.config.optimizer_lr)
|
||||
groups: list[dict[str, object]] = []
|
||||
if other_params:
|
||||
groups.append({"params": other_params, "lr": base_lr, "name": "policy"})
|
||||
if head_params:
|
||||
if backbone_params:
|
||||
groups.append(
|
||||
{"params": head_params, "lr": base_lr * scale, "name": "lm_head"}
|
||||
{"params": backbone_params, "lr": base_lr * backbone_scale, "name": "backbone"}
|
||||
)
|
||||
# Sanity: head_substrings must match at least one parameter, otherwise
|
||||
# the scale silently does nothing — surface that fast.
|
||||
if not head_params:
|
||||
if expert_params:
|
||||
groups.append(
|
||||
{"params": expert_params, "lr": base_lr * expert_scale, "name": "action_expert"}
|
||||
)
|
||||
if head_params:
|
||||
groups.append({"params": head_params, "lr": base_lr * head_scale, "name": "lm_head"})
|
||||
# Sanity: a non-trivial head scale that matches no params would silently
|
||||
# do nothing — surface that fast.
|
||||
if head_scale != 1.0 and not head_params:
|
||||
raise RuntimeError(
|
||||
"lm_head_lr_scale != 1.0 but no parameters matched the LM-head "
|
||||
"name patterns: "
|
||||
f"{head_substrings!r}. Did the underlying PaliGemma module rename?"
|
||||
f"name patterns: {head_substrings!r}. Did the underlying PaliGemma "
|
||||
"module rename?"
|
||||
)
|
||||
logging.info(
|
||||
"PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over "
|
||||
"%d head params + %d other params",
|
||||
scale,
|
||||
"PI052Policy LR groups (base=%.3g): backbone=%.3g (×%.3g, n=%d), "
|
||||
"action_expert=%.3g (×%.3g, n=%d), lm_head=%.3g (×%.3g, n=%d)",
|
||||
base_lr,
|
||||
base_lr * scale,
|
||||
len(head_params),
|
||||
len(other_params),
|
||||
base_lr * backbone_scale, backbone_scale, len(backbone_params),
|
||||
base_lr * expert_scale, expert_scale, len(expert_params),
|
||||
base_lr * head_scale, head_scale, len(head_params),
|
||||
)
|
||||
return groups
|
||||
|
||||
|
||||
@@ -133,7 +133,10 @@ class PiGemmaRMSNorm(nn.Module):
|
||||
if cond.shape[-1] != self.cond_dim:
|
||||
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
|
||||
modulation = self.dense(cond)
|
||||
if len(x.shape) == 3:
|
||||
# Per-sample cond (B, cond_dim) → broadcast over the sequence. A
|
||||
# per-token cond (B, T, cond_dim) is already aligned with x and must
|
||||
# not be unsqueezed (used by pi052's amortized K_repeat path).
|
||||
if len(x.shape) == 3 and modulation.dim() == 2:
|
||||
modulation = modulation.unsqueeze(1)
|
||||
scale, shift, gate = modulation.chunk(3, dim=-1)
|
||||
normed = normed * (1 + scale.float()) + shift.float()
|
||||
|
||||
Reference in New Issue
Block a user