diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py
index 5cab001a9..f657c84b0 100644
--- a/src/lerobot/policies/pi052/configuration_pi052.py
+++ b/src/lerobot/policies/pi052/configuration_pi052.py
@@ -99,6 +99,50 @@ class PI052Config(PI05Config):
     memory_dropout_prob: float = 0.0
     subtask_dropout_prob: float = 0.0
+    # FAST discrete-action supervision — paper §III.B-C ------------------
+    # When enabled, actions are *also* tokenised via the FAST tokenizer
+    # ("physical-intelligence/fast") and supervised with cross-entropy
+    # on the PaliGemma LM head — exactly as in the paper's pre-training
+    # objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
+    # ActionTokenizerProcessorStep is wired into the preprocessor
+    # pipeline when this flag is set; the loss is computed in
+    # PI052Policy.forward.
+    enable_fast_action_loss: bool = False
+    """If True, tokenise actions with the FAST tokenizer and add a
+    cross-entropy loss on the LM head. Off by default because most
+    fine-tuning runs only need the flow head + text supervision; the
+    FAST CE term is most useful when training from a base PaliGemma
+    rather than an existing π0.5 checkpoint."""
+
+    action_tokenizer_name: str = "physical-intelligence/fast"
+    """HF identifier for the FAST action tokenizer."""
+
+    max_action_tokens: int = 256
+    """Maximum number of FAST tokens per action chunk."""
+
+    fast_skip_tokens: int = 128
+    """Number of low-vocab tokens the FAST tokenizer skips to avoid
+    collisions with PaliGemma's text vocabulary."""
+
+    fast_action_loss_weight: float = 1.0
+    """Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
+
+    # Knowledge insulation — paper §III.B --------------------------------
+    # When enabled, gradients from the action expert's flow loss are
+    # *blocked* from flowing back into the VLM's K/V projections. This
+    # prevents the action loss from over-fitting the language backbone
+    # to robot-specific features. Implementation requires a custom
+    # per-layer attention forward that uses ``.detach()`` on VLM K/V
+    # when computing attention for action queries — see
+    # ``pi05_full/modeling_pi05.py::compute_layer_complete_knowledge_insulation``
+    # on branch ``feat/add-pi05`` for the reference implementation.
+    knowledge_insulation: bool = False
+    """If True, route every transformer layer through the KI
+    attention path (blocks action→VLM gradient flow on K/V).
+    Implemented by binding ``_paligemma_forward_ki`` (ported from
+    pi05_full) onto the backbone instance in ``PI052Policy.__init__``.
+    The flag is exposed here so SLURM commands can be written against
+    the final config shape."""
+
 
     def __post_init__(self) -> None:
         super().__post_init__()
         # Backbone needs gradients flowing through the text head when
diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py
index 7f6f7cc86..b5e4e6054 100644
--- a/src/lerobot/policies/pi052/modeling_pi052.py
+++ b/src/lerobot/policies/pi052/modeling_pi052.py
@@ -39,10 +39,12 @@ from __future__ import annotations
 
 import logging
 import math
+import types
 from typing import Any
 
 import torch
 from torch import Tensor
+from torch.nn import functional as F
 
 from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
 
@@ -53,6 +55,190 @@ from .configuration_pi052 import PI052Config
 
 logger = logging.getLogger(__name__)
 
+# ----------------------------------------------------------------------
+# Knowledge insulation — ported from pi05_full (branch ``feat/add-pi05``)
+# ----------------------------------------------------------------------
+#
+# Per-layer attention that splits the queries into VLM and action
+# parts, computing attention for action queries with .detach()'d VLM
+# K/V so the action loss's gradient cannot flow back into the VLM's K
+# and V projections. Forward output is bit-equivalent to the standard
+# layer; backward differs only on the path action_loss → VLM K/V.
+
+
+def _compute_layer_ki(
+    layer_idx,
+    inputs_embeds,
+    attention_mask,
+    position_ids,
+    adarms_cond,
+    paligemma,
+    gemma_expert,
+):
+    from transformers.models.gemma import modeling_gemma  # noqa: PLC0415
+
+    models = [paligemma.language_model, gemma_expert.model]
+    query_states, key_states, value_states, gates = [], [], [], []
+
+    vlm_len = inputs_embeds[0].shape[1]
+
+    for i, hidden_states in enumerate(inputs_embeds):
+        layer = models[i].layers[layer_idx]
+        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
+        gates.append(gate)
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
+        q = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        k = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        v = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        query_states.append(q)
+        key_states.append(k)
+        value_states.append(v)
+
+    query_states = torch.cat(query_states, dim=2)
+    key_states = torch.cat(key_states, dim=2)
+    value_states = torch.cat(value_states, dim=2)
+
+    dummy = torch.zeros(
+        query_states.shape[0],
+        query_states.shape[2],
+        query_states.shape[-1],
+        device=query_states.device,
+        dtype=query_states.dtype,
+    )
+    cos, sin = paligemma.model.language_model.rotary_emb(dummy, position_ids)
+    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, unsqueeze_dim=1
+    )
+
+    batch_size = query_states.shape[0]
+    scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
+
+    # Split queries / K / V at the VLM-vs-action boundary.
+    Q_vlm = query_states[:, :, :vlm_len, :]
+    Q_action = query_states[:, :, vlm_len:, :]
+    K_vlm = key_states[:, :, :vlm_len, :]
+    K_action = key_states[:, :, vlm_len:, :]
+    V_vlm = value_states[:, :, :vlm_len, :]
+    V_action = value_states[:, :, vlm_len:, :]
+
+    # Detach VLM K/V *only* on the path the action queries use.
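+    # The VLM queries below still use the un-detached K/V, so the text
+    # and subtask CE losses keep updating the K/V projections; only the
+    # action-expert branch sees them as constants in backward.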
+    K_vlm_det = K_vlm.detach()
+    V_vlm_det = V_vlm.detach()
+    K_for_vlm = key_states  # full (gradients flow)
+    V_for_vlm = value_states
+    K_for_action = torch.cat([K_vlm_det, K_action], dim=2)
+    V_for_action = torch.cat([V_vlm_det, V_action], dim=2)
+
+    mask_for_vlm = attention_mask[:, :, :vlm_len, :]
+    mask_for_action = attention_mask[:, :, vlm_len:, :]
+
+    att_vlm, _ = modeling_gemma.eager_attention_forward(
+        paligemma.language_model.layers[layer_idx].self_attn,
+        Q_vlm, K_for_vlm, V_for_vlm, mask_for_vlm, scaling,
+    )
+    att_action, _ = modeling_gemma.eager_attention_forward(
+        paligemma.language_model.layers[layer_idx].self_attn,
+        Q_action, K_for_action, V_for_action, mask_for_action, scaling,
+    )
+    att = torch.cat([att_vlm, att_action], dim=1)
+
+    head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
+    att = att.reshape(batch_size, -1, 1 * 8 * head_dim)
+
+    outputs_embeds = []
+    start = 0
+    for i, hidden_states in enumerate(inputs_embeds):
+        layer = models[i].layers[layer_idx]
+        end = start + hidden_states.shape[1]
+        if att.dtype != layer.self_attn.o_proj.weight.dtype:
+            att = att.to(layer.self_attn.o_proj.weight.dtype)
+        out_emb = layer.self_attn.o_proj(att[:, start:end])
+        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
+        after_first = out_emb.clone()
+        out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
+        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
+            out_emb = out_emb.to(dtype=torch.bfloat16)
+        out_emb = layer.mlp(out_emb)
+        out_emb = modeling_gemma._gated_residual(after_first, out_emb, gate)  # noqa: SLF001
+        outputs_embeds.append(out_emb)
+        start = end
+    return outputs_embeds
+
+
+def _paligemma_forward_ki(
+    self,
+    attention_mask=None,
+    position_ids=None,
+    past_key_values=None,
+    inputs_embeds=None,
+    use_cache=None,
+    adarms_cond=None,
+    fill_kv_cache=None,
+):
+    """Replacement ``PaliGemmaWithExpertModel.forward`` that routes the
+    dual-expert layer pass through :func:`_compute_layer_ki`.
+
+    Bound onto the model instance when ``config.knowledge_insulation``
+    is True (see ``PI052Policy.__init__``). Single-expert branches
+    (VLM-only or action-only) reuse the stock implementation because
+    there's no KI signal to add — KI only matters when actions and
+    VLM tokens are forwarded together.
+    """
+    from ..pi05.modeling_pi05 import layernorm_forward  # noqa: PLC0415
+
+    if adarms_cond is None:
+        adarms_cond = [None, None]
+
+    # Single-expert paths: defer to the original (pre-KI) forward.
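+    # (``_pi052_orig_forward`` is the stock forward saved on the
+    # instance by ``PI052Policy.__init__`` before this method is bound.)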
+    if inputs_embeds[0] is None or inputs_embeds[1] is None:
+        if hasattr(self, "_pi052_orig_forward"):
+            return self._pi052_orig_forward(
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                adarms_cond=adarms_cond,
+            )
+        return type(self).__bases__[0].forward(
+            self,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            adarms_cond=adarms_cond,
+        )
+
+    models = [self.paligemma.model.language_model, self.gemma_expert.model]
+    num_layers = self.paligemma.config.text_config.num_hidden_layers
+    use_gc = (
+        hasattr(self.gemma_expert.model, "gradient_checkpointing")
+        and self.gemma_expert.model.gradient_checkpointing
+        and self.training
+    ) or (
+        hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training
+    )
+
+    for layer_idx in range(num_layers):
+        if use_gc:
+            inputs_embeds = torch.utils.checkpoint.checkpoint(
+                _compute_layer_ki,
+                layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
+                use_reentrant=False, preserve_rng_state=False,
+                paligemma=self.paligemma, gemma_expert=self.gemma_expert,
+            )
+        else:
+            inputs_embeds = _compute_layer_ki(
+                layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
+                paligemma=self.paligemma, gemma_expert=self.gemma_expert,
+            )
+
+    outputs_embeds = []
+    for i, hidden_states in enumerate(inputs_embeds):
+        out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
+        outputs_embeds.append(out_emb)
+    return [outputs_embeds[0], outputs_embeds[1]], None
+
+
 class PI052Policy(PI05Policy):
     """π0.5 with the PaliGemma LM head re-enabled."""
 
@@ -68,6 +254,20 @@ class PI052Policy(PI05Policy):
         if config.text_loss_weight > 0 and config.unfreeze_lm_head:
             self._unfreeze_lm_head()
 
+        # Knowledge insulation: bind a custom ``forward`` on the
+        # PaliGemmaWithExpertModel instance that uses
+        # :func:`_compute_layer_ki` for the dual-expert layer pass.
+        # The bind is per-instance, so this doesn't leak into stock
+        # ``pi05`` policies that share the same class.
+        if getattr(config, "knowledge_insulation", False):
+            backbone = self.model.paligemma_with_expert
+            backbone._pi052_orig_forward = backbone.forward
+            backbone.forward = types.MethodType(_paligemma_forward_ki, backbone)
+            logger.info(
+                "PI052: knowledge insulation enabled — action→VLM K/V "
+                "gradients are blocked in attention."
+            )
+
     # ------------------------------------------------------------------
     # Head unfreeze helper
     # ------------------------------------------------------------------
@@ -143,6 +343,22 @@ class PI052Policy(PI05Policy):
             else total + self.config.text_loss_weight * text_loss
         )
 
+        # FAST action-token CE loss (paper §III.C). When
+        # ``enable_fast_action_loss=True`` the preprocessor wrote
+        # ACTION_TOKENS / ACTION_TOKEN_MASK into the batch — we
+        # forward them through the PaliGemma backbone alongside the
+        # language prefix and compute CE on the action positions.
+        if getattr(self.config, "enable_fast_action_loss", False):
+            from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS  # noqa: PLC0415
+
+            action_tokens = batch.get(ACTION_TOKENS)
+            action_mask = batch.get(ACTION_TOKEN_MASK)
+            if action_tokens is not None and action_mask is not None:
+                fast_loss = self._compute_fast_action_loss(batch, action_tokens, action_mask)
+                loss_dict["fast_action_loss"] = float(fast_loss.detach().item())
+                weighted = self.config.fast_action_loss_weight * fast_loss
+                total = weighted if total is None else total + weighted
+
         if total is None:
             # Both flow and text disabled — make this an obvious bug
             # rather than a silent zero loss.
@@ -161,6 +377,82 @@
     # Text loss
     # ------------------------------------------------------------------
 
+    def _compute_fast_action_loss(
+        self,
+        batch: dict[str, Tensor],
+        action_tokens: Tensor,
+        action_mask: Tensor,
+    ) -> Tensor:
+        """Cross-entropy on FAST-tokenised actions via PaliGemma's LM head.
+
+        Mirrors the paper's §III.C action-token loss: append the FAST
+        tokens to the language prefix, forward through the backbone,
+        slice the per-token logits at the action positions, and
+        compute CE against the FAST targets (shifted for next-token
+        prediction). Token-mask gates the loss to valid positions.
+        """
+        from ..pi05.modeling_pi05 import make_att_2d_masks  # noqa: PLC0415
+
+        images, img_masks = self.model._preprocess_images(batch)
+        lang_tokens = batch[OBS_LANGUAGE_TOKENS]
+        lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
+
+        # Re-embed prefix to get [images, language] embeddings, then
+        # append the FAST action token embeddings as additional
+        # prefix tokens. PaliGemma's embed_language_tokens is shared
+        # between text and FAST tokens so we can re-use it directly.
+        prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
+            images, img_masks, lang_tokens, lang_masks
+        )
+        emb_dim = prefix_embs.shape[-1]
+        fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
+        fast_emb = fast_emb * math.sqrt(emb_dim)
+
+        # Concat onto the prefix. Pad masks: language uses
+        # ``lang_masks``; FAST uses ``action_mask`` (True at valid
+        # token positions). Attention masks add ``True`` (causal)
+        # for FAST so they can attend to the bidirectional prefix
+        # but only causally among themselves.
+        bsize, fast_len = action_tokens.shape
+        device = prefix_embs.device
+        ones_att = torch.ones((bsize, fast_len), dtype=torch.bool, device=device)
+        full_embs = torch.cat([prefix_embs, fast_emb], dim=1)
+        full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
+        full_att = torch.cat([prefix_att, ones_att], dim=1)
+
+        att_2d = make_att_2d_masks(full_pad, full_att)
+        position_ids = torch.cumsum(full_pad, dim=1) - 1
+
+        (vlm_out, _), _ = self.model.paligemma_with_expert.forward(
+            attention_mask=att_2d,
+            position_ids=position_ids,
+            past_key_values=None,
+            inputs_embeds=[full_embs, None],
+            use_cache=False,
+            fill_kv_cache=True,
+        )
+        if vlm_out is None:
+            raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.")
+
+        # Slice the last ``fast_len`` positions — those correspond to
+        # the FAST tokens we just appended.
+        fast_hidden = vlm_out[:, -fast_len:, :]
+        lm_head = self.model.paligemma_with_expert.paligemma.lm_head
+        fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
+
+        # Shift for next-token prediction.
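+        # Position t's logits predict token t+1, so logits drop the last
+        # step, targets drop the first, and padded targets are zeroed
+        # out via ``shift_mask`` before averaging.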
+        shift_logits = fast_logits[:, :-1, :].contiguous()
+        shift_targets = action_tokens[:, 1:].contiguous()
+        shift_mask = action_mask[:, 1:].contiguous().float()
+
+        per_tok = F.cross_entropy(
+            shift_logits.view(-1, shift_logits.size(-1)),
+            shift_targets.view(-1),
+            reduction="none",
+        ).view(shift_targets.shape)
+        loss = (per_tok * shift_mask).sum() / shift_mask.sum().clamp_min(1.0)
+        return loss
+
     def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
         """Cross-entropy on PaliGemma's LM head over the supervised span.
 
diff --git a/src/lerobot/policies/pi052/processor_pi052.py b/src/lerobot/policies/pi052/processor_pi052.py
index 6abe1cdcd..7c3f9c4eb 100644
--- a/src/lerobot/policies/pi052/processor_pi052.py
+++ b/src/lerobot/policies/pi052/processor_pi052.py
@@ -43,6 +43,7 @@ import torch
 from lerobot.configs.recipe import TrainingRecipe
 from lerobot.processor import (
     AbsoluteActionsProcessorStep,
+    ActionTokenizerProcessorStep,
     AddBatchDimensionProcessorStep,
     DeviceProcessorStep,
     NormalizerProcessorStep,
@@ -101,9 +102,25 @@ def make_pi052_pre_post_processors(
             memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
             subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
         ),
-        DeviceProcessorStep(device=config.device),
     ]
 
+    # FAST tokenizer for discrete-action CE supervision (paper §III.C).
+    # Only inserted when explicitly enabled — keeps the post-training-
+    # style recipe (flow + text) as the default. When on, the step
+    # writes ACTION_TOKENS / ACTION_TOKEN_MASK into
+    # ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
+    if getattr(config, "enable_fast_action_loss", False):
+        input_steps.append(
+            ActionTokenizerProcessorStep(
+                action_tokenizer_name=config.action_tokenizer_name,
+                max_action_tokens=config.max_action_tokens,
+                fast_skip_tokens=config.fast_skip_tokens,
+                paligemma_tokenizer_name="google/paligemma-3b-pt-224",
+            )
+        )
+
+    input_steps.append(DeviceProcessorStep(device=config.device))
+
     output_steps = [
         UnnormalizerProcessorStep(
             features=config.output_features,