From a594ad796989a39d91e29927d33898a230aabcb6 Mon Sep 17 00:00:00 2001 From: pepijn223 Date: Thu, 4 Jun 2026 19:59:27 +0200 Subject: [PATCH] refactor(pi052): self-contained policy; revert pi0/pi05 to upstream main The smolvla branch had modified the shared pi0/pi05 modeling + pi05 config to support pi052 (SDPA attention, layernorm/lm_head handling, optimizer foreach/fused/lm_head_lr_scale, embedding scaling). Decouple pi052 instead: - Vendor the PI0.5 backbone (PaliGemmaWithExpertModel, PI05Pytorch, helpers) into pi052/pi05_backbone.py (verbatim copy, no PI05Policy). - Flatten PI052Policy to subclass PreTrainedPolicy directly (no longer PI05Policy); inline the needed PI05Policy methods. - Restore optimizer_foreach/fused + get_optimizer_preset on PI052Config. - Revert pi0, pi0_fast, pi05 modeling and configuration_pi05 to origin/main (byte-identical), so the shared policies carry no smolvla modifications. Behavior verified bit-exact on pepijn223/pi052_robocasa_full: embed_language_tokens, predict_action_chunk, and the fused flow+text+FAST training loss are identical before/after (max_abs_diff=0). pi052 tests pass (pre-existing stale-name collection errors unchanged).
Co-authored-by: Cursor --- src/lerobot/policies/pi0/modeling_pi0.py | 6 +- .../policies/pi05/configuration_pi05.py | 17 - src/lerobot/policies/pi05/modeling_pi05.py | 220 ++-- .../policies/pi052/configuration_pi052.py | 19 + src/lerobot/policies/pi052/modeling_pi052.py | 475 ++++++++- src/lerobot/policies/pi052/pi05_backbone.py | 947 ++++++++++++++++++ .../policies/pi0_fast/modeling_pi0_fast.py | 2 - .../pi052/test_pi052_sdpa_attention.py | 2 +- 8 files changed, 1484 insertions(+), 204 deletions(-) create mode 100644 src/lerobot/policies/pi052/pi05_backbone.py diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 753b8d45f..f6f4212fb 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -677,10 +677,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` # Process language tokens def lang_embed_func(lang_tokens): - # embed_language_tokens -> Gemma get_input_embeddings(), which is - # GemmaTextScaledWordEmbedding (transformers >=5.4.0): it already multiplies by - # sqrt(hidden_size) internally. Do NOT scale again here (would double-scale text). - return self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + return lang_emb lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) embs.append(lang_emb) diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 2d4c0e4d8..124e85cc9 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -93,21 +93,6 @@ class PI05Config(PreTrainedConfig): optimizer_eps: float = 1e-8 optimizer_weight_decay: float = 0.01 optimizer_grad_clip_norm: float = 1.0 - optimizer_foreach: bool | None = False - optimizer_fused: bool | None = True - - # LM-head LR multiplier. 
The PaliGemma `lm_head` projection (and its - # tied `embed_tokens`) is the surface the LM head's first-token - # distribution depends on. With ``knowledge_insulation`` blocking - # action→VLM gradients, the LM head only sees gradients on text-CE - # samples — which can be a small fraction of the mix (e.g. ~45% in - # ``subtask_mem.yaml``). Under aggressive cosine LR decay the head's - # first-token distribution can drift back toward PaliGemma's - # pretrained ```` detection prior, despite teacher-forced CE - # staying near zero. Boosting just the LM-head LR (e.g. 5x) keeps - # the head pinned to fine-tuning targets without perturbing the - # backbone / vision tower / action expert. Default 1.0 = no change. - lm_head_lr_scale: float = 1.0 # Scheduler settings: see openpi `CosineDecaySchedule` # Note: These will auto-scale if --steps < scheduler_decay_steps @@ -167,8 +152,6 @@ class PI05Config(PreTrainedConfig): eps=self.optimizer_eps, weight_decay=self.optimizer_weight_decay, grad_clip_norm=self.optimizer_grad_clip_norm, - foreach=self.optimizer_foreach, - fused=self.optimizer_fused, ) def get_scheduler_preset(self): diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 3c30cdeb8..aabd04c6f 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -15,7 +15,6 @@ # limitations under the License. 
import builtins -import copy import logging import math from collections import deque @@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: + from transformers.cache_utils import DynamicCache from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma @@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available: ) else: CONFIG_MAPPING = None + DynamicCache = None modeling_gemma = None PiGemmaForCausalLM = None _gated_residual = None @@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` ( return att_2d_masks & pad_2d_masks +def clone_past_key_values(past_key_values): + """Clone the DynamicCache returned by prefix prefill for compiled denoising.""" + return DynamicCache( + tuple( + (keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values + ) + ) + + def pad_vector(vector, new_dim): """Pad the last dimension of a vector to new_dim with zeros. @@ -223,53 +233,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) return padded_images -def sdpa_attention_forward( - module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, -): - """Drop-in for ``modeling_gemma.eager_attention_forward`` using - ``torch.nn.functional.scaled_dot_product_attention``. - - PyTorch SDPA picks the memory-efficient kernel for arbitrary additive - bias masks (the FA backend only accepts causal/sliding-window). On - H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory - than the eager softmax(QK^T)+matmul path. Mirrors eager's signature - and output shape (``(B, Lq, H, D)``) so call sites are unchanged. 
- """ - n_rep = module.num_key_value_groups - if n_rep > 1: - key = key.repeat_interleave(n_rep, dim=1) - value = value.repeat_interleave(n_rep, dim=1) - if attention_mask is not None and attention_mask.dtype != query.dtype: - attention_mask = attention_mask.to(dtype=query.dtype) - attn_output = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=dropout if module.training else 0.0, - is_causal=False, - scale=scaling, - ) - return attn_output.transpose(1, 2).contiguous(), None - - # Define the complete layer computation function for gradient checkpointing -def compute_layer_complete( - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert -): - models = [paligemma.model.language_model, gemma_expert.model] +def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb): query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): - layer = models[i].layers[layer_idx] + layer = layers[i] hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i]) gates.append(gate) input_shape = hidden_states.shape[:-1] @@ -291,14 +262,16 @@ def compute_layer_complete( device=query_states.device, dtype=query_states.dtype, ) - cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + cos, sin = rotary_emb(dummy_tensor, position_ids) query_states, key_states = modeling_gemma.apply_rotary_pos_emb( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] - scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling - att_output, _ = sdpa_attention_forward( - paligemma.model.language_model.layers[layer_idx].self_attn, + paligemma_layer = layers[0] + scaling = paligemma_layer.self_attn.scaling + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + paligemma_layer.self_attn, 
query_states, key_states, value_states, @@ -306,13 +279,13 @@ def compute_layer_complete( scaling, ) # Get head_dim from the current layer, not from the model - head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim + head_dim = paligemma_layer.self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): - layer = models[i].layers[layer_idx] + layer = layers[i] end_pos = start_pos + hidden_states.shape[1] if att_output.dtype != layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) @@ -444,7 +417,6 @@ class PaliGemmaWithExpertModel( params_to_keep_float32 = [ "vision_tower", "multi_modal_projector", - "lm_head", "input_layernorm", "post_attention_layernorm", "model.norm", @@ -477,17 +449,13 @@ class PaliGemmaWithExpertModel( if image.dtype != torch.float32: image = image.to(torch.float32) image_outputs = self.paligemma.model.get_image_features(image) - # OpenPI / big_vision convention: image (soft) tokens are NOT scaled by the - # Gemma embedder normalizer (sqrt(hidden_size)) — only text tokens are. lerobot/pi05_base - # was trained in this regime, so scaling image features here over-scales them ~45x and - # breaks the pretrained vision-language alignment. Keep image features un-normalized. 
features = image_outputs.pooler_output if features.dtype != out_dtype: features = features.to(out_dtype) return features def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.model.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.get_input_embeddings()(tokens) def forward( self, @@ -525,8 +493,9 @@ class PaliGemmaWithExpertModel( prefix_output = None prefix_past_key_values = None else: - models = [self.paligemma.model.language_model, self.gemma_expert.model] - num_layers = self.paligemma.config.text_config.num_hidden_layers + paligemma_layers = self.paligemma.model.language_model.layers + gemma_expert_layers = self.gemma_expert.model.layers + rotary_emb = self.paligemma.model.language_model.rotary_emb # Check if gradient checkpointing is enabled for any of the models use_gradient_checkpointing = ( @@ -536,36 +505,39 @@ class PaliGemmaWithExpertModel( ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) # Process all layers with gradient checkpointing if enabled - for layer_idx in range(num_layers): + for layers in zip(paligemma_layers, gemma_expert_layers, strict=True): if use_gradient_checkpointing: inputs_embeds = torch.utils.checkpoint.checkpoint( compute_layer_complete, - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, use_reentrant=False, preserve_rng_state=False, - paligemma=self.paligemma, - gemma_expert=self.gemma_expert, + layers=layers, + rotary_emb=rotary_emb, ) else: inputs_embeds = compute_layer_complete( - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, - paligemma=self.paligemma, - gemma_expert=self.gemma_expert, + layers=layers, + rotary_emb=rotary_emb, ) # final norm + final_norms = ( + self.paligemma.model.language_model.norm, + self.gemma_expert.model.norm, + ) + def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = 
layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) + out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds @@ -657,13 +629,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) return func(*args, **kwargs) - def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None): + def _prepare_attention_masks_4d(self, att_2d_masks): """Helper method to prepare 4D attention masks for transformer.""" att_2d_masks_4d = att_2d_masks[:, None, :, :] - result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) - if dtype is not None: - result = result.to(dtype=dtype) - return result + return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) def sample_noise(self, shape, device): return torch.normal( @@ -704,10 +673,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Process language tokens def lang_embed_func(tokens): - # embed_language_tokens -> Gemma embed_tokens, which is GemmaTextScaledWordEmbedding - # (transformers >=5.4.0): it already multiplies by sqrt(hidden_size) internally. Do NOT - # scale again here or text tokens get double-scaled (~45x) and break alignment. - return self.paligemma_with_expert.embed_language_tokens(tokens) + lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) + return lang_emb lang_emb = self._apply_checkpoint(lang_embed_func, tokens) embs.append(lang_emb) @@ -794,22 +761,21 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` att_2d_masks = make_att_2d_masks(pad_masks, att_masks) position_ids = torch.cumsum(pad_masks, dim=1) - 1 - att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype) + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) - # Selective AC: rely on the per-layer checkpoint inside - # ``PaliGemmaWithExpertModel.forward`` (which wraps each - # transformer block individually). 
The previous outer - # ``_apply_checkpoint(forward_func, ...)`` doubled up — it - # re-ran the full backbone forward during backward *and* each - # block's own checkpoint re-ran during that recompute. Pure - # waste with SDPA, which already streams attention activations. - (_, suffix_out), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks_4d, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - adarms_cond=[None, adarms_cond], + def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + return suffix_out + + suffix_out = self._apply_checkpoint( + forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond ) suffix_out = suffix_out[:, -self.config.chunk_size :] @@ -853,9 +819,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - prefix_att_2d_masks_4d = self._prepare_attention_masks_4d( - prefix_att_2d_masks, dtype=prefix_embs.dtype - ) + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( @@ -925,12 +889,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - full_att_2d_masks_4d = self._prepare_attention_masks_4d( - full_att_2d_masks, dtype=suffix_embs.dtype - ) + full_att_2d_masks_4d = 
self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 - past_key_values = copy.deepcopy(past_key_values) + past_key_values = clone_past_key_values(past_key_values) outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, @@ -1065,16 +1027,6 @@ class PI05Policy(PreTrainedPolicy): if remap_count > 0: print(f"Remapped {remap_count} state dict keys") - lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight" - embed_tokens_key = ( - "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" - ) - if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict: - remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float() - print("Initialized PaliGemma lm_head from language token embeddings") - elif lm_head_key in remapped_state_dict: - remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float() - # Load the remapped state dict into the model missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) @@ -1168,62 +1120,8 @@ class PI05Policy(PreTrainedPolicy): return fixed_state_dict - def get_optim_params(self): - """Return policy parameters, optionally split into LR-scaled groups. - - When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head`` - and its tied ``embed_tokens`` are placed in their own param - group with ``lr = base_lr * lm_head_lr_scale``. The cosine - scheduler multiplies both groups by the same lambda each step, - so the ratio is preserved across decay. Default ``1.0`` = - return ``self.parameters()`` (back-compat with existing checkpoints - and configs). 
- """ - scale = float(getattr(self.config, "lm_head_lr_scale", 1.0)) - if scale == 1.0: - return self.parameters() - head_params: list[torch.nn.Parameter] = [] - other_params: list[torch.nn.Parameter] = [] - # Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` — - # boosting only the projection without the embedding pulls them - # apart and breaks the tie that PaliGemma was pre-trained with. - head_substrings = ( - "paligemma_with_expert.paligemma.lm_head.", - "paligemma_with_expert.paligemma.model.language_model.embed_tokens.", - ) - for name, p in self.named_parameters(): - if not p.requires_grad: - continue - if any(s in name for s in head_substrings): - head_params.append(p) - else: - other_params.append(p) - base_lr = float(self.config.optimizer_lr) - groups: list[dict[str, object]] = [] - if other_params: - groups.append({"params": other_params, "lr": base_lr, "name": "policy"}) - if head_params: - groups.append( - {"params": head_params, "lr": base_lr * scale, "name": "lm_head"} - ) - # Sanity: head_substrings must match at least one parameter, otherwise - # the scale silently does nothing — surface that fast. - if not head_params: - raise RuntimeError( - "lm_head_lr_scale != 1.0 but no parameters matched the LM-head " - "name patterns: " - f"{head_substrings!r}. Did the underlying PaliGemma module rename?" 
- ) - logging.info( - "PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over " - "%d head params + %d other params", - scale, - base_lr, - base_lr * scale, - len(head_params), - len(other_params), - ) - return groups + def get_optim_params(self) -> dict: + return self.parameters() def reset(self): """Reset internal state - called when environment resets.""" diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index d2725891f..2b1576929 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -39,6 +39,7 @@ step, which the multi-rate ``PI052Runtime`` (in from dataclasses import dataclass from lerobot.configs import PreTrainedConfig +from lerobot.optim.optimizers import AdamWConfig from ..pi05.configuration_pi05 import PI05Config @@ -206,6 +207,24 @@ class PI052Config(PI05Config): ``DecodingError: The fields use_hf_kernels are not valid for PI052Config`` (job 22164492). Remove in a future major bump.""" + # Optimizer foreach/fused. pi052 carries these locally because the shared + # PI05Config (kept identical to upstream main) does not define them; the + # checkpoints we train serialize both keys into config.json, so they must + # be valid PI052Config fields and flow into the AdamW preset below. 
+ optimizer_foreach: bool | None = False + optimizer_fused: bool | None = True + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + foreach=self.optimizer_foreach, + fused=self.optimizer_fused, + ) + def __post_init__(self) -> None: super().__post_init__() # Backbone needs gradients flowing through the text head when diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 5cc7db457..2d377c648 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -37,19 +37,25 @@ for the LM head. from __future__ import annotations +import builtins import logging import types -from typing import Any +from collections import deque +from pathlib import Path +from typing import Any, Unpack import torch from torch import Tensor from torch.nn import functional as F +from lerobot.configs import PreTrainedConfig from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.import_utils import require_package -from ..pi05.configuration_pi05 import PI05Config -from ..pi05.modeling_pi05 import PI05Policy +from ..pretrained import PreTrainedPolicy, T +from ..rtc.modeling_rtc import RTCProcessor from .configuration_pi052 import PI052Config +from .pi05_backbone import ActionSelectKwargs, PI05Pytorch, pad_vector, resize_with_pad_torch logger = logging.getLogger(__name__) @@ -335,7 +341,7 @@ def _compute_layer_ki( if mask_for_action.dtype != Q_action.dtype: mask_for_action = mask_for_action.to(dtype=Q_action.dtype) - from ..pi05.modeling_pi05 import sdpa_attention_forward # noqa: PLC0415 + from .pi05_backbone import sdpa_attention_forward # noqa: PLC0415 att_vlm, _ = sdpa_attention_forward( 
paligemma.model.language_model.layers[layer_idx].self_attn, @@ -387,7 +393,7 @@ def _paligemma_forward_ki( (VLM-only or action-only) defer back to the original forward — KI only matters when actions and VLM tokens are forwarded together. """ - from ..pi05.modeling_pi05 import layernorm_forward # noqa: PLC0415 + from .pi05_backbone import layernorm_forward # noqa: PLC0415 if adarms_cond is None: adarms_cond = [None, None] @@ -435,21 +441,38 @@ def _paligemma_forward_ki( return [outputs_embeds[0], outputs_embeds[1]], None -class PI052Policy(PI05Policy): - """π0.5 with the PaliGemma LM head re-enabled.""" +class PI052Policy(PreTrainedPolicy): + """π0.5 with the PaliGemma LM head re-enabled. + + Self-contained: the PI0.5 backbone (PaliGemmaWithExpertModel / PI05Pytorch) + is vendored in ``pi05_backbone.py`` and the PI05Policy wrapper logic is + inlined directly here, so this policy does not depend on or inherit from + ``lerobot.policies.pi05`` (which stays identical to ``main``). + """ config_class = PI052Config name = "pi052" def __init__(self, config: PI052Config, **kwargs: Any) -> None: - # Patch ops BEFORE the backbone is built (super().__init__ below - # constructs PaliGemmaWithExpertModel which instantiates the - # Gemma/Siglip layers we want to swap). Always-on — the patch - # is process-global / idempotent and degrades gracefully if - # liger-kernel is missing. + # Patch ops BEFORE the backbone is built (the backbone constructed + # below instantiates the Gemma/Siglip layers we want to swap). + # Always-on — the patch is process-global / idempotent and degrades + # gracefully if liger-kernel is missing. 
_enable_hf_kernels() - super().__init__(config, **kwargs) + # ---- inlined PI05Policy.__init__ ---------------------------------- + require_package("transformers", extra="pi") + super().__init__(config) + config.validate_features() + self.config = config + self.init_rtc_processor() + self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor) + if config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + self.model.to(config.device) + self.reset() + # ---- end inlined PI05Policy.__init__ ------------------------------ + # ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and # freezes a few terminal layers when ``train_expert_only`` is # the (default) True. We re-enable the head if the user @@ -487,7 +510,11 @@ class PI052Policy(PI05Policy): def reset(self): """Reset action and high-level inference state.""" - super().reset() + # inlined PI05Policy.reset + self._action_queue = deque(maxlen=self.config.n_action_steps) + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } self.last_subtasks = None self.last_subtasks_raw = None self.last_subtasks_source = None @@ -556,7 +583,7 @@ class PI052Policy(PI05Policy): and predict_actions_t is None and not getattr(self.config, "enable_fast_action_loss", False) ): - return super().forward(batch, reduction=reduction) + return self._pi05_flow_forward(batch, reduction=reduction) run_flow = ( self.config.flow_loss_weight > 0 @@ -687,7 +714,7 @@ class PI052Policy(PI05Policy): """ from lerobot.utils.constants import ACTION # noqa: PLC0415 - from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 + from .pi05_backbone import make_att_2d_masks # noqa: PLC0415 # ---- preamble (mirrors PI05Pytorch.forward) ------------------ actions = self.prepare_action(batch) @@ -842,7 +869,7 @@ class PI052Policy(PI05Policy): Returns ``(text_loss, fast_loss)``. Either can be ``None`` if the caller doesn't want that head. 
""" - from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 + from .pi05_backbone import make_att_2d_masks # noqa: PLC0415 images, img_masks = self._preprocess_images(batch) lang_tokens = batch[OBS_LANGUAGE_TOKENS] @@ -953,7 +980,7 @@ class PI052Policy(PI05Policy): ``input_ids[t+1]`` for next-token prediction). Returns ``{}`` when the batch has no supervised text positions. """ - from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 + from .pi05_backbone import make_att_2d_masks # noqa: PLC0415 text_labels = batch.get("text_labels") if text_labels is None or not bool((text_labels != -100).any().item()): @@ -1089,7 +1116,7 @@ class PI052Policy(PI05Policy): current_att = prefix_att_masks generated: list[int] = [] - from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 + from .pi05_backbone import make_att_2d_masks # noqa: PLC0415 backbone = self.model.paligemma_with_expert lm_head = backbone.paligemma.lm_head @@ -1397,3 +1424,413 @@ class PI052Policy(PI05Policy): choice = torch.multinomial(sorted_p, num_samples=1) return sorted_ix.gather(-1, choice).squeeze(-1) return torch.multinomial(probs, num_samples=1).squeeze(-1) + + # ------------------------------------------------------------------ + # Inlined from PI05Policy (vendored; pi052 does not inherit pi05). + # Kept verbatim except PI05Policy.forward -> _pi05_flow_forward (the + # flow-only fallback used by PI052Policy.forward on unannotated batches). 
+ # ------------------------------------------------------------------ + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Override the from_pretrained method to handle key remapping and display important disclaimer.""" + print( + "The PI05 model is a direct port of the OpenPI implementation. \n" + "This implementation follows the original OpenPI structure for compatibility. \n" + "Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") + + # Use provided config if available, otherwise create default config + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + model = cls(config, **kwargs) + + # Load state dict (expects keys with "model." 
prefix) + try: + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + token=kwargs.get("token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + original_state_dict = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + + # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys) + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + + # Then add "model." 
prefix for all keys that don't already have it + remapped_state_dict = {} + remap_count = 0 + + for key, value in fixed_state_dict.items(): + if not key.startswith("model."): + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + else: + remapped_state_dict[key] = value + + if remap_count > 0: + print(f"Remapped {remap_count} state dict keys") + + lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight" + embed_tokens_key = ( + "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ) + if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict: + remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float() + print("Initialized PaliGemma lm_head from language token embeddings") + elif lm_head_key in remapped_state_dict: + remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float() + + # Load the remapped state dict into the model + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) + + if missing_keys: + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... 
def _fix_pytorch_state_dict_keys(
    self, state_dict, model_config
):  # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
    """Rename or drop checkpoint keys so they match the current model layout.

    Args:
        state_dict: Mapping of checkpoint parameter names to tensors.
        model_config: Unused; kept so existing callers keep working.

    Returns:
        A new dict with the kept entries under their (possibly renamed) keys.
        Keys that do not apply to this model are dropped with a warning.
    """
    import re

    # Loop-invariant work: compile both patterns and read the adaRMS flag once
    # instead of once per checkpoint key.
    expert_layer_norm = re.compile(
        r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight"
    )
    expert_final_norm = re.compile(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight")
    expert_uses_adarms = getattr(
        self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
    )
    lm_head_keys = (
        "model.paligemma_with_expert.paligemma.lm_head.weight",
        "paligemma_with_expert.paligemma.lm_head.weight",
    )
    embed_tokens_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"

    fixed_state_dict = {}

    for key, value in state_dict.items():
        new_key = key

        # Plain `.weight` norm entries of the expert are dropped when the expert
        # is configured with adaRMS (the checkpoint layout does not match).
        if expert_uses_adarms:
            if expert_layer_norm.match(key):
                logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
                continue
            if expert_final_norm.match(key):
                logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
                continue

        # Handle MLP naming changes for pi05
        # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_*
        if key.startswith("action_time_mlp_in."):
            new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
        elif key.startswith("action_time_mlp_out."):
            new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")

        # state_proj shouldn't exist in pi05
        if key.startswith("state_proj."):
            logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
            continue

        # Vision tower patch embeddings are kept as-is; only a warning is emitted.
        if "patch_embedding" in key:
            logging.warning(f"Vision embedding key might need handling: {key}")

        # Mirror the LM head weight into the token-embedding slot (as a copy).
        if key in lm_head_keys:
            fixed_state_dict[embed_tokens_key] = value.clone()

        fixed_state_dict[new_key] = value

    return fixed_state_dict
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    """Convert the batch's camera images into model inputs plus validity masks.

    Each image listed in ``config.image_features`` and present in ``batch`` is
    moved to the model device, cast to float32, resized with padding when its
    spatial size differs from ``config.image_resolution``, and rescaled with
    ``img * 2 - 1``. Cameras listed in the config but absent from the batch are
    replaced by images filled with -1 and an all-False mask.

    Args:
        batch: Mapping from feature name to tensor.
            # assumes images are [B, C, H, W] or [B, H, W, C] in [0, 1] — TODO confirm

    Returns:
        ``(images, img_masks)``: two lists of equal length, one entry per
        configured camera (present ones first, then placeholders).

    Raises:
        ValueError: If none of the configured image features is in the batch.
    """
    images = []
    img_masks = []

    # Get device from model parameters
    device = next(self.parameters()).device

    present_img_keys = [key for key in self.config.image_features if key in batch]
    missing_img_keys = [key for key in self.config.image_features if key not in batch]

    if len(present_img_keys) == 0:
        raise ValueError(
            f"All image features are missing from the batch. At least one expected. "
            f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
        )

    # Normalize to a tuple: a list-valued resolution (e.g. from a deserialized
    # config) never compares equal to a `torch.Size`, which would force a
    # resize on every image.
    target_resolution = tuple(self.config.image_resolution)

    # Preprocess image features present in the batch
    for key in present_img_keys:
        img = batch[key]

        # Ensure tensor is on the same device as the model
        if img.device != device:
            img = img.to(device)

        # Ensure float32 dtype for consistency
        if img.dtype != torch.float32:
            img = img.to(torch.float32)

        # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
        is_channels_first = img.shape[1] == 3  # Check if channels are in dimension 1

        if is_channels_first:
            # Convert [B, C, H, W] to [B, H, W, C] for processing
            img = img.permute(0, 2, 3, 1)

        # from openpi preprocess_observation_pytorch: Resize with padding if needed
        if tuple(img.shape[1:3]) != target_resolution:
            img = resize_with_pad_torch(img, *target_resolution)

        # Normalize from [0,1] to [-1,1] as expected by siglip
        img = img * 2.0 - 1.0

        # Convert back to [B, C, H, W] format if it was originally channels-first
        if is_channels_first:
            img = img.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

        images.append(img)
        # Create mask (all ones for real images)
        img_masks.append(torch.ones(img.shape[0], dtype=torch.bool, device=device))

    # Placeholders for cameras missing from the batch, shaped like the last real image
    reference_img = images[-1]
    reference_mask = img_masks[-1]
    for _ in missing_img_keys:
        images.append(torch.ones_like(reference_img) * -1)  # Padded with -1 for SigLIP
        img_masks.append(torch.zeros_like(reference_mask))  # Mask is zero for empty cameras

    return images, img_masks
def _get_default_peft_targets(self) -> dict[str, object]:
    """Return default PEFT target modules for PI0.5 fine-tuning.

    Returns:
        A dict with two entries: ``"target_modules"``, a regular expression
        selecting the ``q_proj``/``v_proj`` attention projections under
        ``gemma_expert`` plus the listed ``model.<projection>`` modules, and
        ``"modules_to_save"``, an empty list.
    """
    # NOTE(review): `PI05Pytorch` in this patch defines `action_in_proj`,
    # `action_out_proj`, `time_mlp_in` and `time_mlp_out`; `state_proj` and
    # `action_time_mlp_*` presumably match nothing when `self.model` is that
    # class — confirm whether the `time_mlp_*` layers should be targeted.
    common_projections = (
        "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
    )
    target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
    return {
        "target_modules": target_modules,
        "modules_to_save": [],
    }
team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import copy +import logging +import math +from collections import deque +from pathlib import Path +from typing import TYPE_CHECKING, Literal, TypedDict, Unpack + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn + +from lerobot.utils.import_utils import _transformers_available, require_package + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.gemma import modeling_gemma + + from ..pi_gemma import ( + PaliGemmaForConditionalGenerationWithPiGemma, + PiGemmaForCausalLM, + _gated_residual, + layernorm_forward, + ) +else: + CONFIG_MAPPING = None + modeling_gemma = None + PiGemmaForCausalLM = None + _gated_residual = None + layernorm_forward = None + PaliGemmaForConditionalGenerationWithPiGemma = None +from lerobot.configs import PreTrainedConfig +from lerobot.utils.constants import ( + ACTION, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OPENPI_ATTENTION_MASK_VALUE, +) + +from ..pretrained import PreTrainedPolicy, T +from ..rtc.modeling_rtc import RTCProcessor +from ..pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config + + +class ActionSelectKwargs(TypedDict, total=False): + inference_delay: int | None + prev_chunk_left_over: Tensor | None + execution_horizon: int | None + + +def 
def create_sinusoidal_pos_embedding(  # adapted from openpi `create_sinusoidal_pos_embedding`
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Compute sine-cosine positional embedding vectors for scalar positions.

    Args:
        time: 1-D tensor of positions, one per batch element.
        dimension: Embedding size; must be even.
        min_period: Shortest period of the sinusoids.
        max_period: Longest period of the sinusoids.
        device: Device for the intermediate frequency tensor. Accepts a
            ``torch.device`` or a device string such as ``"cpu"``.

    Returns:
        Tensor of shape ``(len(time), dimension)``: the sine half followed by
        the cosine half.

    Raises:
        ValueError: If ``dimension`` is odd or ``time`` is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # The default is the string "cpu", which has no `.type`; normalize so both
    # strings and `torch.device` objects work.
    device = torch.device(device)

    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def resize_with_pad_torch(  # adapted from openpi `resize_with_pad_torch`
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """Resize images to a target size without distortion, padding with zeros.

    The aspect ratio is preserved; the resized image is centered and the
    remaining border is filled with 0. float32 results are clamped to
    ``[0, 1]`` and uint8 results to ``[0, 255]``.

    Args:
        images: Tensor of shape ``[b, h, w, c]`` / ``[h, w, c]`` (channels-last,
            assumed when the last dimension is <= 4) or ``[b, c, h, w]`` /
            ``[c, h, w]`` (channels-first). dtype must be uint8 or float32.
        height: Target height.
        width: Target width.
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor in the same layout and rank as the input.

    Raises:
        ValueError: If the image dtype is neither uint8 nor float32.
    """
    # Layout is guessed from the trailing dimension: <= 4 means channels-last.
    channels_last = images.shape[-1] <= 4
    # Remember whether a batch axis had to be added so it can be removed again.
    added_batch_dim = images.dim() == 3
    if added_batch_dim:
        images = images.unsqueeze(0)
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]

    _, _, cur_height, cur_width = images.shape

    # Scale by the larger ratio so the result fits inside the target box
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    resized_images = F.interpolate(
        images,
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
    )

    # Handle dtype-specific clipping
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(0.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Split the leftover space evenly; any odd pixel goes to the bottom/right
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    constant_value = 0 if images.dtype == torch.uint8 else 0.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    # Convert back to the caller's layout and rank
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
    if added_batch_dim:
        padded_images = padded_images.squeeze(0)

    return padded_images
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
    layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
    """Run one transformer layer jointly over the PaliGemma and expert streams.

    Each stream is normalized and projected to q/k/v with its own layer
    weights; the projections of both streams are concatenated along the
    sequence axis, rotary embeddings from the PaliGemma language model are
    applied, and a single attention pass is computed. The attention output is
    then split back per stream and passed through that stream's output
    projection, gated residuals, post-attention norm and MLP.

    Args:
        layer_idx: Index of the layer to run in both models.
        inputs_embeds: Per-stream hidden states, PaliGemma first, expert second.
            # assumes each is (batch, seq, hidden) — TODO confirm
        attention_mask: Mask forwarded to ``sdpa_attention_forward``.
        position_ids: Positions forwarded to the rotary embedding.
        adarms_cond: Per-stream conditioning passed to ``layernorm_forward``.
        paligemma: Model exposing ``model.language_model``.
        gemma_expert: Model exposing ``model``.

    Returns:
        List with one output hidden-state tensor per input stream.
    """
    models = [paligemma.model.language_model, gemma_expert.model]
    query_states = []
    key_states = []
    value_states = []
    gates = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states.append(query_state)
        key_states.append(key_state)
        value_states.append(value_state)
    # Concatenate and process attention
    query_states = torch.cat(query_states, dim=2)
    key_states = torch.cat(key_states, dim=2)
    value_states = torch.cat(value_states, dim=2)
    # Only used to hand the rotary embedding a device/dtype and a sequence length
    dummy_tensor = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )
    batch_size = query_states.shape[0]
    scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
    att_output, _ = sdpa_attention_forward(
        paligemma.model.language_model.layers[layer_idx].self_attn,
        query_states,
        key_states,
        value_states,
        attention_mask,
        scaling,
    )
    # Merge the head and head_dim axes of the (B, L, H, D) attention output.
    # The width is derived from the tensor itself rather than a hard-coded
    # 8 heads, so other head counts keep the sequence axis intact.
    att_output = att_output.reshape(batch_size, att_output.shape[1], -1)
    # Process layer outputs
    outputs_embeds = []
    start_pos = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        end_pos = start_pos + hidden_states.shape[1]
        if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
            att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
        # first residual
        out_emb = _gated_residual(hidden_states, out_emb, gates[i])
        after_first_residual = out_emb.clone()
        out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
        # Convert to bfloat16 if the next layer (mlp) uses bfloat16
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        # second residual
        out_emb = _gated_residual(after_first_residual, out_emb, gate)
        outputs_embeds.append(out_emb)
        start_pos = end_pos
    return outputs_embeds
vlm_config_hf.text_config.head_dim = vlm_config.head_dim + vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth + vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads + vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" + vlm_config_hf.text_config.dtype = "float32" + vlm_config_hf.text_config.vocab_size = 257152 + vlm_config_hf.text_config.use_adarms = use_adarms[0] + vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.image_size = image_size + vlm_config_hf.vision_config.intermediate_size = 4304 + vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" + vlm_config_hf.vision_config.dtype = "float32" + + action_expert_config_hf = CONFIG_MAPPING["gemma"]( + head_dim=action_expert_config.head_dim, + hidden_size=action_expert_config.width, + intermediate_size=action_expert_config.mlp_dim, + num_attention_heads=action_expert_config.num_heads, + num_hidden_layers=action_expert_config.depth, + num_key_value_heads=action_expert_config.num_kv_heads, + vocab_size=257152, + hidden_activation="gelu_pytorch_tanh", + dtype="float32", + use_adarms=use_adarms[1], + adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, + ) + + self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf) + self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf) + self.gemma_expert.model.embed_tokens = None + + self.to_bfloat16_for_selected_params(precision) + self._set_requires_grad() + + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") + + # Keep full vision path in float32 so we never toggle (toggle causes optimizer + # "same dtype" 
error). Saves memory vs full float32; more memory than only 3 params. + params_to_keep_float32 = [ + "vision_tower", + "multi_modal_projector", + "lm_head", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] + + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) + + def _set_requires_grad(self): + if self.freeze_vision_encoder: + self.paligemma.model.vision_tower.eval() + for param in self.paligemma.model.vision_tower.parameters(): + param.requires_grad = False + if self.train_expert_only: + self.paligemma.eval() + for param in self.paligemma.parameters(): + param.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + if self.freeze_vision_encoder: + self.paligemma.model.vision_tower.eval() + if self.train_expert_only: + self.paligemma.eval() + + def embed_image(self, image: torch.Tensor): + # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). + out_dtype = image.dtype + if image.dtype != torch.float32: + image = image.to(torch.float32) + image_outputs = self.paligemma.model.get_image_features(image) + # OpenPI / big_vision convention: image (soft) tokens are NOT scaled by the + # Gemma embedder normalizer (sqrt(hidden_size)) — only text tokens are. lerobot/pi05_base + # was trained in this regime, so scaling image features here over-scales them ~45x and + # breaks the pretrained vision-language alignment. Keep image features un-normalized. 
+ features = image_outputs.pooler_output + if features.dtype != out_dtype: + features = features.to(out_dtype) + return features + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.model.language_model.embed_tokens(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ): + if adarms_cond is None: + adarms_cond = [None, None] + if inputs_embeds[1] is None: + prefix_output = self.paligemma.model.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[0] if adarms_cond is not None else None, + ) + prefix_past_key_values = prefix_output.past_key_values + prefix_output = prefix_output.last_hidden_state + suffix_output = None + elif inputs_embeds[0] is None: + suffix_output = self.gemma_expert.model.forward( + inputs_embeds=inputs_embeds[1], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[1] if adarms_cond is not None else None, + ) + suffix_output = suffix_output.last_hidden_state + prefix_output = None + prefix_past_key_values = None + else: + models = [self.paligemma.model.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + + # Check if gradient checkpointing is enabled for any of the models + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, "gradient_checkpointing") + and self.gemma_expert.model.gradient_checkpointing + and self.training + ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) + + # Process all 
layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + else: + inputs_embeds = compute_layer_complete( + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + + # final norm + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None + + return [prefix_output, suffix_output], prefix_past_key_values + + +class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` + """Core PI05 PyTorch model.""" + + def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None): + super().__init__() + self.config = config + self.rtc_processor = rtc_processor + + paligemma_config = get_gemma_config(config.paligemma_variant) + action_expert_config = get_gemma_config(config.action_expert_variant) + + if config.image_resolution[0] != config.image_resolution[1]: + raise ValueError( + f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}" + ) + + self.paligemma_with_expert = 
PaliGemmaWithExpertModel( + paligemma_config, + action_expert_config, + use_adarms=[False, True], + precision=config.dtype, + image_size=config.image_resolution[0], + freeze_vision_encoder=config.freeze_vision_encoder, + train_expert_only=config.train_expert_only, + ) + + self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim) + + self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) + self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + # Also compile the main forward pass used during training + self.forward = torch.compile(self.forward, mode=config.compile_mode) + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing for PI05Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + logging.info("Disabled gradient checkpointing for PI05Pytorch model") + + def 
_rtc_enabled(self): + return self.config.rtc_config is not None and self.config.rtc_config.enabled + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + if dtype is not None: + result = result.to(dtype=dtype) + return result + + def sample_noise(self, shape, device): + return torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + + def sample_time(self, bsize, device): + time_beta = sample_beta( + self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device + ) + time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, tokens, masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + att_masks += [0] * num_img_embs + + # Process language tokens + def lang_embed_func(tokens): + # embed_language_tokens -> Gemma embed_tokens, which is GemmaTextScaledWordEmbedding + 
# (transformers >=5.4.0): it already multiplies by sqrt(hidden_size) internally. Do NOT + # scale again here or text tokens get double-scaled (~45x) and break alignment. + return self.paligemma_with_expert.embed_language_tokens(tokens) + + lang_emb = self._apply_checkpoint(lang_embed_func, tokens) + embs.append(lang_emb) + pad_masks.append(masks) + + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + bsize = pad_masks.shape[0] + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def embed_suffix(self, noisy_actions, timestep): + """Embed noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Embed timestep using sine-cosine positional encoding + time_emb = create_sinusoidal_pos_embedding( + timestep, + self.action_in_proj.out_features, + min_period=self.config.min_period, + max_period=self.config.max_period, + device=timestep.device, + ) + time_emb = time_emb.type(dtype=timestep.dtype) + + # Fuse timestep + action information using an MLP + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) + + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) + x = self.time_mlp_out(x) + return F.silu(x) + + time_emb = self._apply_checkpoint(time_mlp_func, time_emb) + action_time_emb = action_emb + adarms_cond = time_emb + + embs.append(action_time_emb) + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] + ([0] * 
(self.config.chunk_size - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks, adarms_cond + + def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor: + """Do a full training forward pass and compute the loss.""" + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) + + if ( + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype) + + # Selective AC: rely on the per-layer checkpoint inside + # ``PaliGemmaWithExpertModel.forward`` (which wraps each + # transformer block individually). The previous outer + # ``_apply_checkpoint(forward_func, ...)`` doubled up — it + # re-ran the full backbone forward during backward *and* each + # block's own checkpoint re-ran during that recompute. Pure + # waste with SDPA, which already streams attention activations. 
+ (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + + def action_out_proj_func(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) + + return F.mse_loss(u_t, v_t, reduction="none") + + @torch.no_grad() # see openpi `sample_actions` (slightly adapted) + def sample_actions( + self, + images, + img_masks, + tokens, + masks, + noise=None, + num_steps=None, + **kwargs: Unpack[ActionSelectKwargs], + ) -> Tensor: + """Do a full inference forward and compute the action.""" + if num_steps is None: + num_steps = self.config.num_inference_steps + + bsize = tokens.shape[0] + device = tokens.device + + if noise is None: + # Sample noise with padded dimension as expected by action_in_proj + actions_shape = ( + bsize, + self.config.chunk_size, + self.config.max_action_dim, + ) # Use config max_action_dim for internal processing + noise = self.sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d( + prefix_att_2d_masks, dtype=prefix_embs.dtype + ) + self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001 + + _, past_key_values = self.paligemma_with_expert.forward( + attention_mask=prefix_att_2d_masks_4d, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=True, + ) + + dt = -1.0 / num_steps + + x_t = noise + for 
step in range(num_steps): + time = 1.0 + step * dt + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) + + def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): + return self.denoise_step( + prefix_pad_masks=prefix_pad_masks, + past_key_values=past_key_values, + x_t=input_x_t, + timestep=current_timestep, + ) + + if self._rtc_enabled(): + inference_delay = kwargs.get("inference_delay") + prev_chunk_left_over = kwargs.get("prev_chunk_left_over") + execution_horizon = kwargs.get("execution_horizon") + + v_t = self.rtc_processor.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev_chunk_left_over, + inference_delay=inference_delay, + time=time, + original_denoise_step_partial=denoise_step_partial_call, + execution_horizon=execution_horizon, + ) + else: + v_t = denoise_step_partial_call(x_t) + + x_t = x_t + dt * v_t + + if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): + self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) + + return x_t + + def denoise_step( + self, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + full_att_2d_masks_4d = self._prepare_attention_masks_4d( + full_att_2d_masks, dtype=suffix_embs.dtype + ) + 
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + + past_key_values = copy.deepcopy(past_key_values) + outputs_embeds, _ = self.paligemma_with_expert.forward( + attention_mask=full_att_2d_masks_4d, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + + suffix_out = outputs_embeds[1] + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + return self.action_out_proj(suffix_out) + diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index d342cffaf..d9342eb24 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -268,8 +268,6 @@ class PI0FastPaliGemma(nn.Module): return features def embed_language_tokens(self, tokens: torch.Tensor): - # get_input_embeddings() is GemmaTextScaledWordEmbedding (transformers >=5.4.0): it already - # multiplies by sqrt(hidden_size) internally, so no manual scaling is needed here. return self.paligemma.model.language_model.get_input_embeddings()(tokens) def forward( diff --git a/tests/policies/pi052/test_pi052_sdpa_attention.py b/tests/policies/pi052/test_pi052_sdpa_attention.py index 808e80faf..218a5fa3e 100644 --- a/tests/policies/pi052/test_pi052_sdpa_attention.py +++ b/tests/policies/pi052/test_pi052_sdpa_attention.py @@ -33,7 +33,7 @@ pytest.importorskip("transformers") from transformers.models.gemma import modeling_gemma # noqa: E402 -from lerobot.policies.pi05.modeling_pi05 import ( # noqa: E402 +from lerobot.policies.pi052.pi05_backbone import ( # noqa: E402 make_att_2d_masks, sdpa_attention_forward, )