diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index b47117bb8..ff5ff395d 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -168,7 +168,7 @@ class PI052Config(PI05Config): # a per-instance monkey-patch on ``paligemma_with_expert.forward`` # that splits queries into VLM and action halves and ``.detach()``-s # the VLM K/V tensors used in the action-half's attention. - knowledge_insulation: bool = False + knowledge_insulation: bool = True """If True, route every transformer layer through the KI attention path that blocks action→VLM gradient flow on K/V.""" @@ -189,6 +189,30 @@ class PI052Config(PI05Config): # the same cosine lambda, so the 5x ratio is preserved across decay. lm_head_lr_scale: float = 5.0 + # Separate LRs for the VLM backbone vs the action expert (paper §III.B). + # The backbone is a pretrained PaliGemma; the action expert is trained + # from scratch, so their initialisation scales differ and a single global + # LR under-trains one of them. These multipliers scale the base + # ``optimizer_lr`` for each group; the cosine scheduler applies the same + # lambda to every group so the ratios hold across decay. ``backbone_lr_scale`` + # covers the PaliGemma tower (except the LM head / tied embeddings, which keep + # their own ``lm_head_lr_scale``); ``action_expert_lr_scale`` covers the Gemma + # expert plus the action/time projection heads. Defaults of 1.0 reproduce the + # single-LR behaviour (back-compat with existing checkpoints/configs). + backbone_lr_scale: float = 1.0 + action_expert_lr_scale: float = 1.0 + + # Amortized flow training (paper §III.B, K_repeat). The VLM/backbone forward + # dominates step cost; to extract more learning signal per VLM pass the action + # expert runs ``flow_num_repeats`` denoising targets per sample, each with an + # independent noise + timestep draw, all attending to the single shared VLM + # prefix. The per-repeat flow losses are averaged, so the backbone gradient + # stays well-scaled. Pairs naturally with ``knowledge_insulation`` (which + # additionally detaches the prefix K/V on the action path), the paper's + # setting — but the amortized path is also correct without it. Set to 1 to + # recover the original single-draw combined forward. + flow_num_repeats: int = 5 + # PaLM-style z-loss on text CE. Penalises the log-partition function # ``z = log Σ exp(logits)`` drifting away from zero — without it, large- # vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded @@ -250,3 +274,5 @@ class PI052Config(PI05Config): # out of text training via ``text_loss_weight=0``. if self.text_loss_weight > 0 and self.unfreeze_lm_head: self.train_expert_only = False + if self.flow_num_repeats < 1: + raise ValueError(f"flow_num_repeats must be >= 1, got {self.flow_num_repeats}") diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 0161e588b..e290ac0f8 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -1257,39 +1257,21 @@ class PI052Policy(PreTrainedPolicy): action_code_mask: Tensor | None, predict_actions_t: Tensor | None = None, ) -> tuple[Tensor, Tensor | None, Tensor | None]: - """Full fusion: flow + text + FAST in ONE backbone forward. + """Flow + text + FAST losses, sharing a single VLM prefix forward. - Builds: - - prefix = [images, language, FAST (when provided)] - suffix = [noisy_actions] (action expert via gemma_expert) - - Then overrides the unified 2D attention mask to *explicitly* - zero out ``suffix → FAST`` attention. Without this override - the action expert would attend to the discrete FAST tokens - and trivially decode them back to the same continuous - actions it's supposed to predict via flow matching — the - whole training signal collapses. - - Both prefix_out and suffix_out are captured from the same - forward. From prefix_out we slice the language and FAST - token positions and compute their CE losses. From suffix_out - we run the existing flow path (action_out_proj → MSE). + Embeds ``prefix = [images, language, FAST (when provided)]`` once, then + computes the flow loss via either a single combined forward + (``flow_num_repeats == 1``) or the amortized K-repeat path + (``> 1``); both keep the discrete FAST tokens invisible to the action + expert. The text/FAST CE losses are sliced from the shared + ``prefix_out``. Returns ``(flow_loss, text_loss, fast_loss)`` where text/fast can be ``None`` when the caller didn't supply the corresponding inputs. """ - from lerobot.utils.constants import ACTION # noqa: PLC0415 - - # ---- preamble (mirrors PI05Pytorch.forward) ------------------ actions = self.prepare_action(batch) - noise = self.model.sample_noise(actions.shape, actions.device) - time = self.model.sample_time(actions.shape[0], actions.device) - time_expanded = time[:, None, None] - x_t = time_expanded * noise + (1 - time_expanded) * actions - u_t = noise - actions # ---- prefix: images + language + (optional FAST) ------------- images, img_masks = self._preprocess_images(batch) @@ -1325,6 +1307,52 @@ class PI052Policy(PreTrainedPolicy): prefix_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1) prefix_att = torch.cat([prefix_att, ones_att], dim=1) + # ---- flow: one combined forward, or amortized over K repeats ---- + # ``flow_num_repeats == 1`` keeps the single combined [prefix; suffix] + # forward. ``> 1`` runs the VLM prefix once and replays the action + # expert K times against fresh noise/timestep draws, reusing the + # cached prefix KV (paper §III.B). Both return ``prefix_out`` for the + # shared text/FAST CE tail. + num_repeats = int(getattr(self.config, "flow_num_repeats", 1)) + if num_repeats > 1: + prefix_out, flow_loss = self._amortized_prefix_and_flow( + actions, prefix_embs, prefix_pad, prefix_att, + non_fast_prefix_len, fast_len, predict_actions_t, num_repeats, + ) + else: + prefix_out, flow_loss = self._combined_prefix_and_flow( + actions, prefix_embs, prefix_pad, prefix_att, + non_fast_prefix_len, fast_len, predict_actions_t, + ) + + text_loss, fast_loss = self._prefix_ce_losses( + prefix_out, text_labels, action_tokens, action_code_mask, fast_len, predict_actions_t + ) + return flow_loss, text_loss, fast_loss + + def _combined_prefix_and_flow( + self, + actions: Tensor, + prefix_embs: Tensor, + prefix_pad: Tensor, + prefix_att: Tensor, + non_fast_prefix_len: int, + fast_len: int, + predict_actions_t: Tensor | None, + ) -> tuple[Tensor, Tensor]: + """Single combined [prefix; suffix] forward → (prefix_out, flow_loss). + + This is the original (``flow_num_repeats == 1``) path: one noise/time + draw, one backbone forward producing both the VLM prefix hidden states + (for text/FAST CE) and the action-expert suffix hidden states (flow).""" + from lerobot.utils.constants import ACTION # noqa: PLC0415 + + noise = self.model.sample_noise(actions.shape, actions.device) + time = self.model.sample_time(actions.shape[0], actions.device) + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + # ---- suffix: noisy actions ---------------------------------- suffix_embs, suffix_pad, suffix_att, adarms_cond = self.model.embed_suffix(x_t, time) @@ -1380,8 +1408,132 @@ class PI052Policy(PreTrainedPolicy): flow_per_dim = flow_per_dim[:, :, :original_action_dim] per_sample_flow = flow_per_dim.mean(dim=(1, 2)) flow_loss = _mask_per_sample(per_sample_flow, predict_actions_t) + return prefix_out, flow_loss - # ---- text + FAST CE from prefix_out ------------------------ + def _amortized_prefix_and_flow( + self, + actions: Tensor, + prefix_embs: Tensor, + prefix_pad: Tensor, + prefix_att: Tensor, + non_fast_prefix_len: int, + fast_len: int, + predict_actions_t: Tensor | None, + num_repeats: int, + ) -> tuple[Tensor, Tensor]: + """Amortized flow: one VLM prefix forward, K action-expert replays. + + The VLM/backbone forward dominates step cost, so we keep a *single* + combined forward but tile the action suffix into ``num_repeats`` blocks, + each with an independent noise/timestep draw against the same action + chunk (paper §III.B, K_repeat). The blocks attend to the shared prefix + (FAST columns masked, exactly like the combined path) and are + block-diagonal among themselves, so the expensive prefix K/V is computed + once while the cheap action expert runs K times. Knowledge insulation + (``_compute_layer_ki``) detaches the prefix K/V for the action queries, + so this is gradient-equivalent to K independent draws sharing one VLM + forward. Per-block flow losses are averaged. + """ + from lerobot.utils.constants import ACTION # noqa: PLC0415 + + model = self.model + k = num_repeats + chunk = self.config.chunk_size + batch_size, prefix_len = prefix_pad.shape + + first_layer = model.paligemma_with_expert.paligemma.model.language_model.layers[0] + use_bf16 = first_layer.self_attn.q_proj.weight.dtype == torch.bfloat16 + if use_bf16: + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + # ---- K suffix blocks: independent noise/time draws ---------- + suffix_blocks: list[Tensor] = [] + adarms_blocks: list[Tensor] = [] + u_t_blocks: list[Tensor] = [] + suffix_pad = suffix_att = None + for _ in range(k): + noise = model.sample_noise(actions.shape, actions.device) + time = model.sample_time(actions.shape[0], actions.device) + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t_blocks.append(noise - actions) + s_embs, suffix_pad, suffix_att, adarms = model.embed_suffix(x_t, time) + if use_bf16: + s_embs = s_embs.to(dtype=torch.bfloat16) + suffix_blocks.append(s_embs) + # adaRMS time conditioning is per-sample; broadcast it across this + # block's chunk tokens so each block carries its own timestep. + adarms_blocks.append(adarms[:, None, :].expand(batch_size, chunk, adarms.shape[-1])) + + suffix_embs = torch.cat(suffix_blocks, dim=1) # (B, k*chunk, D) + adarms_cond = torch.cat(adarms_blocks, dim=1) # (B, k*chunk, cond_dim) + + # ---- block-diagonal attention over [prefix | block_1..k] ---- + # Prefix rows keep their own (causal/text) attention and never see the + # action blocks. Each action block attends to the valid prefix (minus + # FAST) and only to itself. + prefix_att_2d = make_att_2d_masks(prefix_pad, prefix_att) # (B, P, P) + device = prefix_pad.device + prefix_rows = torch.cat( + [prefix_att_2d, torch.zeros(batch_size, prefix_len, k * chunk, dtype=torch.bool, device=device)], + dim=2, + ) + + action_to_prefix = prefix_pad[:, None, :].expand(batch_size, k * chunk, prefix_len).clone() + if fast_len > 0: + action_to_prefix[:, :, non_fast_prefix_len:prefix_len] = False + block_diag = torch.block_diag( + *[torch.ones(chunk, chunk, dtype=torch.bool, device=device) for _ in range(k)] + ) + action_to_action = block_diag[None].expand(batch_size, k * chunk, k * chunk) + action_rows = torch.cat([action_to_prefix, action_to_action], dim=2) + + att_2d = torch.cat([prefix_rows, action_rows], dim=1) # (B, P + k*chunk, P + k*chunk) + att_2d_4d = model._prepare_attention_masks_4d(att_2d, dtype=prefix_embs.dtype) + + # Positions: prefix as usual; every block restarts at the prefix offset + # (each block is an independent denoising of the same chunk). + prefix_offsets = torch.sum(prefix_pad, dim=-1)[:, None] + block_positions = prefix_offsets + torch.cumsum(suffix_pad, dim=1) - 1 # (B, chunk) + position_ids = torch.cat([torch.cumsum(prefix_pad, dim=1) - 1, block_positions.repeat(1, k)], dim=1) + + (prefix_out, suffix_out), _ = model.paligemma_with_expert.forward( + attention_mask=att_2d_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + + # ---- flow loss averaged over the K blocks ------------------- + original_action_dim = self.config.output_features[ACTION].shape[0] + flow_accum: Tensor | None = None + for i in range(k): + block_out = suffix_out[:, i * chunk : (i + 1) * chunk].to(dtype=torch.float32) + v_t = model.action_out_proj(block_out) + flow_per_dim = F.mse_loss(u_t_blocks[i], v_t, reduction="none")[:, :, :original_action_dim] + per_sample_flow = flow_per_dim.mean(dim=(1, 2)) + flow_accum = per_sample_flow if flow_accum is None else flow_accum + per_sample_flow + + per_sample_flow = flow_accum / k + flow_loss = _mask_per_sample(per_sample_flow, predict_actions_t) + return prefix_out, flow_loss + + def _prefix_ce_losses( + self, + prefix_out: Tensor | None, + text_labels: Tensor | None, + action_tokens: Tensor | None, + action_code_mask: Tensor | None, + fast_len: int, + predict_actions_t: Tensor | None, + ) -> tuple[Tensor | None, Tensor | None]: + """Text-CE + FAST-CE from the VLM prefix hidden states. + + Shared by the combined and amortized flow paths: slices the language + and FAST token positions out of ``prefix_out`` and runs the fused + linear-CE heads. Either loss is ``None`` when its inputs are absent.""" lm_head = self.model.paligemma_with_expert.paligemma.lm_head text_loss: Tensor | None = None @@ -1412,7 +1564,7 @@ class PI052Policy(PreTrainedPolicy): predict_actions_t, ) - return flow_loss, text_loss, fast_loss + return text_loss, fast_loss def _compute_text_and_fast_loss( self, @@ -2241,57 +2393,68 @@ class PI052Policy(PreTrainedPolicy): def get_optim_params(self): """Return policy parameters, optionally split into LR-scaled groups. - When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head`` - and its tied ``embed_tokens`` are placed in their own param - group with ``lr = base_lr * lm_head_lr_scale``. The cosine - scheduler multiplies both groups by the same lambda each step, - so the ratio is preserved across decay. Default ``1.0`` = - return ``self.parameters()`` (back-compat with existing checkpoints - and configs). + Three orthogonal multipliers scale the base ``optimizer_lr``: + ``lm_head_lr_scale`` (PaliGemma ``lm_head`` + tied ``embed_tokens``), + ``backbone_lr_scale`` (the rest of the PaliGemma tower), and + ``action_expert_lr_scale`` (the Gemma expert + action/time projection + heads). The cosine scheduler multiplies every group by the same lambda + each step so the ratios are preserved across decay. When all three are + ``1.0`` this returns ``self.parameters()`` (back-compat with existing + checkpoints and configs). """ - scale = float(getattr(self.config, "lm_head_lr_scale", 1.0)) - if scale == 1.0: + head_scale = float(getattr(self.config, "lm_head_lr_scale", 1.0)) + backbone_scale = float(getattr(self.config, "backbone_lr_scale", 1.0)) + expert_scale = float(getattr(self.config, "action_expert_lr_scale", 1.0)) + if head_scale == 1.0 and backbone_scale == 1.0 and expert_scale == 1.0: return self.parameters() - head_params: list[torch.nn.Parameter] = [] - other_params: list[torch.nn.Parameter] = [] - # Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` — - # boosting only the projection without the embedding pulls them - # apart and breaks the tie that PaliGemma was pre-trained with. + + # Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` go in the + # head group — boosting only the projection without the embedding pulls + # them apart and breaks the tie PaliGemma was pre-trained with. head_substrings = ( "paligemma_with_expert.paligemma.lm_head.", "paligemma_with_expert.paligemma.model.language_model.embed_tokens.", ) + backbone_substring = "paligemma_with_expert.paligemma." + head_params: list[torch.nn.Parameter] = [] + backbone_params: list[torch.nn.Parameter] = [] + expert_params: list[torch.nn.Parameter] = [] for name, p in self.named_parameters(): if not p.requires_grad: continue if any(s in name for s in head_substrings): head_params.append(p) + elif backbone_substring in name: + backbone_params.append(p) else: - other_params.append(p) + expert_params.append(p) base_lr = float(self.config.optimizer_lr) groups: list[dict[str, object]] = [] - if other_params: - groups.append({"params": other_params, "lr": base_lr, "name": "policy"}) - if head_params: + if backbone_params: groups.append( - {"params": head_params, "lr": base_lr * scale, "name": "lm_head"} + {"params": backbone_params, "lr": base_lr * backbone_scale, "name": "backbone"} ) - # Sanity: head_substrings must match at least one parameter, otherwise - # the scale silently does nothing — surface that fast. - if not head_params: + if expert_params: + groups.append( + {"params": expert_params, "lr": base_lr * expert_scale, "name": "action_expert"} + ) + if head_params: + groups.append({"params": head_params, "lr": base_lr * head_scale, "name": "lm_head"}) + # Sanity: a non-trivial head scale that matches no params would silently + # do nothing — surface that fast. + if head_scale != 1.0 and not head_params: raise RuntimeError( "lm_head_lr_scale != 1.0 but no parameters matched the LM-head " - "name patterns: " - f"{head_substrings!r}. Did the underlying PaliGemma module rename?" + f"name patterns: {head_substrings!r}. Did the underlying PaliGemma " + "module rename?" ) logging.info( - "PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over " - "%d head params + %d other params", - scale, + "PI052Policy LR groups (base=%.3g): backbone=%.3g (×%.3g, n=%d), " + "action_expert=%.3g (×%.3g, n=%d), lm_head=%.3g (×%.3g, n=%d)", base_lr, - base_lr * scale, - len(head_params), - len(other_params), + base_lr * backbone_scale, backbone_scale, len(backbone_params), + base_lr * expert_scale, expert_scale, len(expert_params), + base_lr * head_scale, head_scale, len(head_params), ) return groups diff --git a/src/lerobot/policies/pi_gemma.py b/src/lerobot/policies/pi_gemma.py index c8631cbff..6201b0148 100644 --- a/src/lerobot/policies/pi_gemma.py +++ b/src/lerobot/policies/pi_gemma.py @@ -133,7 +133,10 @@ class PiGemmaRMSNorm(nn.Module): if cond.shape[-1] != self.cond_dim: raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}") modulation = self.dense(cond) - if len(x.shape) == 3: + # Per-sample cond (B, cond_dim) → broadcast over the sequence. A + # per-token cond (B, T, cond_dim) is already aligned with x and must + # not be unsqueezed (used by pi052's amortized K_repeat path). + if len(x.shape) == 3 and modulation.dim() == 2: modulation = modulation.unsqueeze(1) scale, shift, gate = modulation.chunk(3, dim=-1) normed = normed * (1 + scale.float()) + shift.float()