refactor(pi052): self-contained policy; revert pi0/pi05 to upstream main

The smolvla branch had modified the shared pi0/pi05 modeling + pi05 config to
support pi052 (SDPA attention, layernorm/lm_head handling, optimizer
foreach/fused/lm_head_lr_scale, embedding scaling). Decouple pi052 instead:

- Vendor the PI0.5 backbone (PaliGemmaWithExpertModel, PI05Pytorch, helpers)
  into pi052/pi05_backbone.py (verbatim copy, no PI05Policy).
- Flatten PI052Policy to subclass PreTrainedPolicy directly (no longer
  PI05Policy); inline the needed PI05Policy methods.
- Restore optimizer_foreach/fused + get_optimizer_preset on PI052Config.
- Revert pi0, pi0_fast, pi05 modeling and configuration_pi05 to origin/main
  (byte-identical), so the shared policies carry no smolvla modifications.

Behavior verified bit-exact on pepijn223/pi052_robocasa_full: embed_language_
tokens, predict_action_chunk, and the fused flow+text+FAST training loss are
identical before/after (max_abs_diff=0). pi052 tests pass (pre-existing
stale-name collection errors unchanged).

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-04 19:59:27 +02:00
parent 8292548f0d
commit a594ad7969
8 changed files with 1484 additions and 204 deletions
+2 -4
View File
@@ -677,10 +677,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Process language tokens
def lang_embed_func(lang_tokens):
# embed_language_tokens -> Gemma get_input_embeddings(), which is
# GemmaTextScaledWordEmbedding (transformers >=5.4.0): it already multiplies by
# sqrt(hidden_size) internally. Do NOT scale again here (would double-scale text).
return self.paligemma_with_expert.embed_language_tokens(lang_tokens)
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
return lang_emb
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
@@ -93,21 +93,6 @@ class PI05Config(PreTrainedConfig):
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
optimizer_foreach: bool | None = False
optimizer_fused: bool | None = True
# LM-head LR multiplier. The PaliGemma `lm_head` projection (and its
# tied `embed_tokens`) is the surface the LM head's first-token
# distribution depends on. With ``knowledge_insulation`` blocking
# action→VLM gradients, the LM head only sees gradients on text-CE
# samples — which can be a small fraction of the mix (e.g. ~45% in
# ``subtask_mem.yaml``). Under aggressive cosine LR decay the head's
# first-token distribution can drift back toward PaliGemma's
# pretrained ``<loc>`` detection prior, despite teacher-forced CE
# staying near zero. Boosting just the LM-head LR (e.g. 5x) keeps
# the head pinned to fine-tuning targets without perturbing the
# backbone / vision tower / action expert. Default 1.0 = no change.
lm_head_lr_scale: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
# Note: These will auto-scale if --steps < scheduler_decay_steps
@@ -167,8 +152,6 @@ class PI05Config(PreTrainedConfig):
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
foreach=self.optimizer_foreach,
fused=self.optimizer_fused,
)
def get_scheduler_preset(self):
+59 -161
View File
@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -223,53 +233,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
return padded_images
def sdpa_attention_forward(
module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
):
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
``torch.nn.functional.scaled_dot_product_attention``.
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
bias masks (the FA backend only accepts causal/sliding-window). On
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
"""
n_rep = module.num_key_value_groups
if n_rep > 1:
key = key.repeat_interleave(n_rep, dim=1)
value = value.repeat_interleave(n_rep, dim=1)
if attention_mask is not None and attention_mask.dtype != query.dtype:
attention_mask = attention_mask.to(dtype=query.dtype)
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=dropout if module.training else 0.0,
is_causal=False,
scale=scaling,
)
return attn_output.transpose(1, 2).contiguous(), None
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -291,14 +262,16 @@ def compute_layer_complete(
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
cos, sin = rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
att_output, _ = sdpa_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma_layer.self_attn,
query_states,
key_states,
value_states,
@@ -306,13 +279,13 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma_layer.self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -444,7 +417,6 @@ class PaliGemmaWithExpertModel(
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"lm_head",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -477,17 +449,13 @@ class PaliGemmaWithExpertModel(
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
# OpenPI / big_vision convention: image (soft) tokens are NOT scaled by the
# Gemma embedder normalizer (sqrt(hidden_size)) — only text tokens are. lerobot/pi05_base
# was trained in this regime, so scaling image features here over-scales them ~45x and
# breaks the pretrained vision-language alignment. Keep image features un-normalized.
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
def forward(
self,
@@ -525,8 +493,9 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -536,36 +505,39 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -657,13 +629,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
return func(*args, **kwargs)
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
def _prepare_attention_masks_4d(self, att_2d_masks):
"""Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :]
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
if dtype is not None:
result = result.to(dtype=dtype)
return result
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
def sample_noise(self, shape, device):
return torch.normal(
@@ -704,10 +673,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Process language tokens
def lang_embed_func(tokens):
# embed_language_tokens -> Gemma embed_tokens, which is GemmaTextScaledWordEmbedding
# (transformers >=5.4.0): it already multiplies by sqrt(hidden_size) internally. Do NOT
# scale again here or text tokens get double-scaled (~45x) and break alignment.
return self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
return lang_emb
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
@@ -794,22 +761,21 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
# Selective AC: rely on the per-layer checkpoint inside
# ``PaliGemmaWithExpertModel.forward`` (which wraps each
# transformer block individually). The previous outer
# ``_apply_checkpoint(forward_func, ...)`` doubled up — it
# re-ran the full backbone forward during backward *and* each
# block's own checkpoint re-ran during that recompute. Pure
# waste with SDPA, which already streams attention activations.
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
return suffix_out
suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
suffix_out = suffix_out[:, -self.config.chunk_size :]
@@ -853,9 +819,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
prefix_att_2d_masks, dtype=prefix_embs.dtype
)
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
@@ -925,12 +889,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
full_att_2d_masks_4d = self._prepare_attention_masks_4d(
full_att_2d_masks, dtype=suffix_embs.dtype
)
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
past_key_values = clone_past_key_values(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -1065,16 +1027,6 @@ class PI05Policy(PreTrainedPolicy):
if remap_count > 0:
print(f"Remapped {remap_count} state dict keys")
lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight"
embed_tokens_key = (
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
)
if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict:
remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float()
print("Initialized PaliGemma lm_head from language token embeddings")
elif lm_head_key in remapped_state_dict:
remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float()
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
@@ -1168,62 +1120,8 @@ class PI05Policy(PreTrainedPolicy):
return fixed_state_dict
def get_optim_params(self):
"""Return policy parameters, optionally split into LR-scaled groups.
When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head``
and its tied ``embed_tokens`` are placed in their own param
group with ``lr = base_lr * lm_head_lr_scale``. The cosine
scheduler multiplies both groups by the same lambda each step,
so the ratio is preserved across decay. Default ``1.0`` =
return ``self.parameters()`` (back-compat with existing checkpoints
and configs).
"""
scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
if scale == 1.0:
return self.parameters()
head_params: list[torch.nn.Parameter] = []
other_params: list[torch.nn.Parameter] = []
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` —
# boosting only the projection without the embedding pulls them
# apart and breaks the tie that PaliGemma was pre-trained with.
head_substrings = (
"paligemma_with_expert.paligemma.lm_head.",
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.",
)
for name, p in self.named_parameters():
if not p.requires_grad:
continue
if any(s in name for s in head_substrings):
head_params.append(p)
else:
other_params.append(p)
base_lr = float(self.config.optimizer_lr)
groups: list[dict[str, object]] = []
if other_params:
groups.append({"params": other_params, "lr": base_lr, "name": "policy"})
if head_params:
groups.append(
{"params": head_params, "lr": base_lr * scale, "name": "lm_head"}
)
# Sanity: head_substrings must match at least one parameter, otherwise
# the scale silently does nothing — surface that fast.
if not head_params:
raise RuntimeError(
"lm_head_lr_scale != 1.0 but no parameters matched the LM-head "
"name patterns: "
f"{head_substrings!r}. Did the underlying PaliGemma module rename?"
)
logging.info(
"PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over "
"%d head params + %d other params",
scale,
base_lr,
base_lr * scale,
len(head_params),
len(other_params),
)
return groups
def get_optim_params(self) -> dict:
return self.parameters()
def reset(self):
"""Reset internal state - called when environment resets."""
@@ -39,6 +39,7 @@ step, which the multi-rate ``PI052Runtime`` (in
from dataclasses import dataclass
from lerobot.configs import PreTrainedConfig
from lerobot.optim.optimizers import AdamWConfig
from ..pi05.configuration_pi05 import PI05Config
@@ -206,6 +207,24 @@ class PI052Config(PI05Config):
``DecodingError: The fields use_hf_kernels are not valid for
PI052Config`` (job 22164492). Remove in a future major bump."""
# Optimizer foreach/fused. pi052 carries these locally because the shared
# PI05Config (kept identical to upstream main) does not define them; the
# checkpoints we train serialize both keys into config.json, so they must
# be valid PI052Config fields and flow into the AdamW preset below.
optimizer_foreach: bool | None = False
optimizer_fused: bool | None = True
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
foreach=self.optimizer_foreach,
fused=self.optimizer_fused,
)
def __post_init__(self) -> None:
super().__post_init__()
# Backbone needs gradients flowing through the text head when
+456 -19
View File
@@ -37,19 +37,25 @@ for the LM head.
from __future__ import annotations
import builtins
import logging
import types
from typing import Any
from collections import deque
from pathlib import Path
from typing import Any, Unpack
import torch
from torch import Tensor
from torch.nn import functional as F
from lerobot.configs import PreTrainedConfig
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.utils.import_utils import require_package
from ..pi05.configuration_pi05 import PI05Config
from ..pi05.modeling_pi05 import PI05Policy
from ..pretrained import PreTrainedPolicy, T
from ..rtc.modeling_rtc import RTCProcessor
from .configuration_pi052 import PI052Config
from .pi05_backbone import ActionSelectKwargs, PI05Pytorch, pad_vector, resize_with_pad_torch
logger = logging.getLogger(__name__)
@@ -335,7 +341,7 @@ def _compute_layer_ki(
if mask_for_action.dtype != Q_action.dtype:
mask_for_action = mask_for_action.to(dtype=Q_action.dtype)
from ..pi05.modeling_pi05 import sdpa_attention_forward # noqa: PLC0415
from .pi05_backbone import sdpa_attention_forward # noqa: PLC0415
att_vlm, _ = sdpa_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
@@ -387,7 +393,7 @@ def _paligemma_forward_ki(
(VLM-only or action-only) defer back to the original forward
KI only matters when actions and VLM tokens are forwarded together.
"""
from ..pi05.modeling_pi05 import layernorm_forward # noqa: PLC0415
from .pi05_backbone import layernorm_forward # noqa: PLC0415
if adarms_cond is None:
adarms_cond = [None, None]
@@ -435,21 +441,38 @@ def _paligemma_forward_ki(
return [outputs_embeds[0], outputs_embeds[1]], None
class PI052Policy(PI05Policy):
"""π0.5 with the PaliGemma LM head re-enabled."""
class PI052Policy(PreTrainedPolicy):
"""π0.5 with the PaliGemma LM head re-enabled.
Self-contained: the PI0.5 backbone (PaliGemmaWithExpertModel / PI05Pytorch)
is vendored in ``pi05_backbone.py`` and the PI05Policy wrapper logic is
inlined directly here, so this policy does not depend on or inherit from
``lerobot.policies.pi05`` (which stays identical to ``main``).
"""
config_class = PI052Config
name = "pi052"
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
# Patch ops BEFORE the backbone is built (super().__init__ below
# constructs PaliGemmaWithExpertModel which instantiates the
# Gemma/Siglip layers we want to swap). Always-on — the patch
# is process-global / idempotent and degrades gracefully if
# liger-kernel is missing.
# Patch ops BEFORE the backbone is built (the backbone constructed
# below instantiates the Gemma/Siglip layers we want to swap).
# Always-on — the patch is process-global / idempotent and degrades
# gracefully if liger-kernel is missing.
_enable_hf_kernels()
super().__init__(config, **kwargs)
# ---- inlined PI05Policy.__init__ ----------------------------------
require_package("transformers", extra="pi")
super().__init__(config)
config.validate_features()
self.config = config
self.init_rtc_processor()
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
if config.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
self.model.to(config.device)
self.reset()
# ---- end inlined PI05Policy.__init__ ------------------------------
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
# freezes a few terminal layers when ``train_expert_only`` is
# the (default) True. We re-enable the head if the user
@@ -487,7 +510,11 @@ class PI052Policy(PI05Policy):
def reset(self):
"""Reset action and high-level inference state."""
super().reset()
# inlined PI05Policy.reset
self._action_queue = deque(maxlen=self.config.n_action_steps)
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
self.last_subtasks = None
self.last_subtasks_raw = None
self.last_subtasks_source = None
@@ -556,7 +583,7 @@ class PI052Policy(PI05Policy):
and predict_actions_t is None
and not getattr(self.config, "enable_fast_action_loss", False)
):
return super().forward(batch, reduction=reduction)
return self._pi05_flow_forward(batch, reduction=reduction)
run_flow = (
self.config.flow_loss_weight > 0
@@ -687,7 +714,7 @@ class PI052Policy(PI05Policy):
"""
from lerobot.utils.constants import ACTION # noqa: PLC0415
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
from .pi05_backbone import make_att_2d_masks # noqa: PLC0415
# ---- preamble (mirrors PI05Pytorch.forward) ------------------
actions = self.prepare_action(batch)
@@ -842,7 +869,7 @@ class PI052Policy(PI05Policy):
Returns ``(text_loss, fast_loss)``. Either can be ``None`` if
the caller doesn't want that head.
"""
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
from .pi05_backbone import make_att_2d_masks # noqa: PLC0415
images, img_masks = self._preprocess_images(batch)
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
@@ -953,7 +980,7 @@ class PI052Policy(PI05Policy):
``input_ids[t+1]`` for next-token prediction). Returns ``{}``
when the batch has no supervised text positions.
"""
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
from .pi05_backbone import make_att_2d_masks # noqa: PLC0415
text_labels = batch.get("text_labels")
if text_labels is None or not bool((text_labels != -100).any().item()):
@@ -1089,7 +1116,7 @@ class PI052Policy(PI05Policy):
current_att = prefix_att_masks
generated: list[int] = []
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
from .pi05_backbone import make_att_2d_masks # noqa: PLC0415
backbone = self.model.paligemma_with_expert
lm_head = backbone.paligemma.lm_head
@@ -1397,3 +1424,413 @@ class PI052Policy(PI05Policy):
choice = torch.multinomial(sorted_p, num_samples=1)
return sorted_ix.gather(-1, choice).squeeze(-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)
# ------------------------------------------------------------------
# Inlined from PI05Policy (vendored; pi052 does not inherit pi05).
# Kept verbatim except PI05Policy.forward -> _pi05_flow_forward (the
# flow-only fallback used by PI052Policy.forward on unannotated batches).
# ------------------------------------------------------------------
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = True,
**kwargs,
) -> T:
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"The PI05 model is a direct port of the OpenPI implementation. \n"
"This implementation follows the original OpenPI structure for compatibility. \n"
"Original implementation: https://github.com/Physical-Intelligence/openpi"
)
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
# Use provided config if available, otherwise create default config
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Load state dict (expects keys with "model." prefix)
try:
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
from safetensors.torch import load_file
original_state_dict = load_file(resolved_file)
print("✓ Loaded state dict from model.safetensors")
except Exception as e:
print(f"Could not load state dict from remote files: {e}")
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
remapped_state_dict = {}
remap_count = 0
for key, value in fixed_state_dict.items():
if not key.startswith("model."):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
else:
remapped_state_dict[key] = value
if remap_count > 0:
print(f"Remapped {remap_count} state dict keys")
lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight"
embed_tokens_key = (
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
)
if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict:
remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float()
print("Initialized PaliGemma lm_head from language token embeddings")
elif lm_head_key in remapped_state_dict:
remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float()
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
if missing_keys:
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not load state dict: {e}")
return model
def _fix_pytorch_state_dict_keys(
self, state_dict, model_config
): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
"""Fix state dict keys to match current model architecture."""
import re
fixed_state_dict = {}
for key, value in state_dict.items():
new_key = key
# Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
# For gemma expert layers
if re.match(
r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
key,
):
# Check if the model actually has adaRMS enabled for the expert
expert_uses_adarms = getattr(
self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
)
if expert_uses_adarms:
logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
continue
if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
# Check if the model actually has adaRMS enabled for the expert
expert_uses_adarms = getattr(
self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
)
if expert_uses_adarms:
logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
continue
# Handle MLP naming changes for pi05
# pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_*
if key.startswith("action_time_mlp_in."):
new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
elif key.startswith("action_time_mlp_out."):
new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")
# Also handle state_proj which shouldn't exist in pi05
if key.startswith("state_proj."):
logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
continue
# Handle vision tower embedding layer potential differences
if "patch_embedding" in key:
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict
def get_optim_params(self):
"""Return policy parameters, optionally split into LR-scaled groups.
When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head``
and its tied ``embed_tokens`` are placed in their own param
group with ``lr = base_lr * lm_head_lr_scale``. The cosine
scheduler multiplies both groups by the same lambda each step,
so the ratio is preserved across decay. Default ``1.0`` =
return ``self.parameters()`` (back-compat with existing checkpoints
and configs).
"""
scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
if scale == 1.0:
return self.parameters()
head_params: list[torch.nn.Parameter] = []
other_params: list[torch.nn.Parameter] = []
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` —
# boosting only the projection without the embedding pulls them
# apart and breaks the tie that PaliGemma was pre-trained with.
head_substrings = (
"paligemma_with_expert.paligemma.lm_head.",
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.",
)
for name, p in self.named_parameters():
if not p.requires_grad:
continue
if any(s in name for s in head_substrings):
head_params.append(p)
else:
other_params.append(p)
base_lr = float(self.config.optimizer_lr)
groups: list[dict[str, object]] = []
if other_params:
groups.append({"params": other_params, "lr": base_lr, "name": "policy"})
if head_params:
groups.append(
{"params": head_params, "lr": base_lr * scale, "name": "lm_head"}
)
# Sanity: head_substrings must match at least one parameter, otherwise
# the scale silently does nothing — surface that fast.
if not head_params:
raise RuntimeError(
"lm_head_lr_scale != 1.0 but no parameters matched the LM-head "
"name patterns: "
f"{head_substrings!r}. Did the underlying PaliGemma module rename?"
)
logging.info(
"PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over "
"%d head params + %d other params",
scale,
base_lr,
base_lr * scale,
len(head_params),
len(other_params),
)
return groups
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
"""
images = []
img_masks = []
# Get device from model parameters
device = next(self.parameters()).device
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. "
f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
)
# Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key]
# Ensure tensor is on the same device as the model
if img.device != device:
img = img.to(device)
# Ensure float32 dtype for consistency
if img.dtype != torch.float32:
img = img.to(torch.float32)
# from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
if is_channels_first:
# Convert [B, C, H, W] to [B, H, W, C] for processing
img = img.permute(0, 2, 3, 1)
# from openpi preprocess_observation_pytorch: Resize with padding if needed
if img.shape[1:3] != self.config.image_resolution:
img = resize_with_pad_torch(img, *self.config.image_resolution)
# Normalize from [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
# from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
if is_channels_first:
img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
images.append(img)
# Create mask (all ones for real images)
bsize = img.shape[0]
mask = torch.ones(bsize, dtype=torch.bool, device=device)
img_masks.append(mask)
# Create image features not present in the batch as fully 0 padded images
for _num_empty_cameras in range(len(missing_img_keys)):
img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP
mask = torch.zeros_like(mask) # Mask is zero for empty cameras
images.append(img)
img_masks.append(mask)
return images, img_masks
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
return actions
def _pi05_flow_forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training.
Args:
batch: Training batch containing observations and actions.
reduction: How to reduce the loss. Options:
- "mean": Return scalar mean loss (default, backward compatible)
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
"""
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
noise = self.model.sample_noise(actions.shape, actions.device)
time = self.model.sample_time(actions.shape[0], actions.device)
# Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise, time)
# Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0]
losses = losses[:, :, :original_action_dim]
loss_dict = {
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
}
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss_dict["loss"] = loss.item()
return loss, loss_dict
def _get_default_peft_targets(self) -> dict[str, any]:
"""Return default PEFT target modules for PI0.5 fine-tuning."""
common_projections = (
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
)
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
return {
"target_modules": target_modules,
"modules_to_save": [],
}
+947
View File
@@ -0,0 +1,947 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from ..pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
else:
CONFIG_MAPPING = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs import PreTrainedConfig
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OPENPI_ATTENTION_MASK_VALUE,
)
from ..pretrained import PreTrainedPolicy, T
from ..rtc.modeling_rtc import RTCProcessor
from ..pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
return torch.float32
if device_type == "cpu":
# CPU doesn't support bfloat16, use float32 instead
if target_dtype == torch.bfloat16:
return torch.float32
if target_dtype == torch.float64:
return torch.float64
return target_dtype
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
alpha_t = torch.tensor(alpha, dtype=torch.float32)
beta_t = torch.tensor(beta, dtype=torch.float32)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,)).to(device)
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
return att_2d_masks & pad_2d_masks
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] >= new_dim:
return vector
return F.pad(vector, (0, new_dim - vector.shape[-1]))
def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
images: torch.Tensor,
height: int,
width: int,
mode: str = "bilinear",
) -> torch.Tensor:
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
by padding with black. If the image is float32, it must be in the range [-1, 1].
Args:
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
height: Target height
width: Target width
mode: Interpolation mode ('bilinear', 'nearest', etc.)
Returns:
Resized and padded tensor with same shape format as input
"""
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
if images.shape[-1] <= 4: # Assume channels-last format
channels_last = True
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
else:
channels_last = False
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
batch_size, channels, cur_height, cur_width = images.shape
# Calculate resize ratio
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
# Resize
resized_images = F.interpolate(
images,
size=(resized_height, resized_width),
mode=mode,
align_corners=False if mode == "bilinear" else None,
)
# Handle dtype-specific clipping
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(0.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
# Calculate padding
pad_h0, remainder_h = divmod(height - resized_height, 2)
pad_h1 = pad_h0 + remainder_h
pad_w0, remainder_w = divmod(width - resized_width, 2)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else 0.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
mode="constant",
value=constant_value,
)
# Convert back to original format if needed
if channels_last:
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
return padded_images
def sdpa_attention_forward(
module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
):
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
``torch.nn.functional.scaled_dot_product_attention``.
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
bias masks (the FA backend only accepts causal/sliding-window). On
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
"""
n_rep = module.num_key_value_groups
if n_rep > 1:
key = key.repeat_interleave(n_rep, dim=1)
value = value.repeat_interleave(n_rep, dim=1)
if attention_mask is not None and attention_mask.dtype != query.dtype:
attention_mask = attention_mask.to(dtype=query.dtype)
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=dropout if module.training else 0.0,
is_causal=False,
scale=scaling,
)
return attn_output.transpose(1, 2).contiguous(), None
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# Concatenate and process attention
query_states = torch.cat(query_states, dim=2)
key_states = torch.cat(key_states, dim=2)
value_states = torch.cat(value_states, dim=2)
dummy_tensor = torch.zeros(
query_states.shape[0],
query_states.shape[2],
query_states.shape[-1],
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
att_output, _ = sdpa_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
attention_mask,
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
class GemmaConfig: # see openpi `gemma.py: Config`
"""Configuration for Gemma model variants."""
def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
self.width = width
self.depth = depth
self.mlp_dim = mlp_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config`
"""Returns config for specified gemma variant."""
if variant == "gemma_300m":
return GemmaConfig(
width=1024,
depth=18,
mlp_dim=4096,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
elif variant == "gemma_2b":
return GemmaConfig(
width=2048,
depth=18,
mlp_dim=16_384,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
else:
raise ValueError(f"Unknown variant: {variant}")
class PaliGemmaWithExpertModel(
nn.Module
): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi
"""PaliGemma model with action expert for PI05."""
def __init__(
self,
vlm_config,
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
freeze_vision_encoder: bool = False,
train_expert_only: bool = False,
):
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
vlm_config_hf.image_token_index = 257152
vlm_config_hf.text_config.hidden_size = vlm_config.width
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
hidden_size=action_expert_config.width,
intermediate_size=action_expert_config.mlp_dim,
num_attention_heads=action_expert_config.num_heads,
num_hidden_layers=action_expert_config.depth,
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
self._set_requires_grad()
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16":
self.to(dtype=torch.bfloat16)
elif precision == "float32":
self.to(dtype=torch.float32)
return
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"lm_head",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32)
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
for param in self.paligemma.parameters():
param.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
# OpenPI / big_vision convention: image (soft) tokens are NOT scaled by the
# Gemma embedder normalizer (sqrt(hidden_size)) — only text tokens are. lerobot/pi05_base
# was trained in this regime, so scaling image features here over-scales them ~45x and
# breaks the pretrained vision-language alignment. Keep image features un-normalized.
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: list[torch.FloatTensor] | None = None,
use_cache: bool | None = None,
adarms_cond: list[torch.Tensor] | None = None,
):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
prefix_output = prefix_output.last_hidden_state
suffix_output = None
elif inputs_embeds[0] is None:
suffix_output = self.gemma_expert.model.forward(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
suffix_output = suffix_output.last_hidden_state
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
hasattr(self.gemma_expert.model, "gradient_checkpointing")
and self.gemma_expert.model.gradient_checkpointing
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
# final norm
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
# Apply gradient checkpointing to final norm if enabled
if use_gradient_checkpointing:
outputs_embeds = torch.utils.checkpoint.checkpoint(
compute_final_norms,
inputs_embeds,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
prefix_output = outputs_embeds[0]
suffix_output = outputs_embeds[1]
prefix_past_key_values = None
return [prefix_output, suffix_output], prefix_past_key_values
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI05 PyTorch model."""
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
if config.image_resolution[0] != config.image_resolution[1]:
raise ValueError(
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, True],
precision=config.dtype,
image_size=config.image_resolution[0],
freeze_vision_encoder=config.freeze_vision_encoder,
train_expert_only=config.train_expert_only,
)
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim)
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
# Compile model if requested
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
return torch.utils.checkpoint.checkpoint(
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
)
return func(*args, **kwargs)
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
"""Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :]
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
if dtype is not None:
result = result.to(dtype=dtype)
return result
def sample_noise(self, shape, device):
return torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
def sample_time(self, bsize, device):
time_beta = sample_beta(
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
)
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, tokens, masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer."""
embs = []
pad_masks = []
att_masks = []
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
def image_embed_func(img):
return self.paligemma_with_expert.embed_image(img)
img_emb = self._apply_checkpoint(image_embed_func, img)
bsize, num_img_embs = img_emb.shape[:2]
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
att_masks += [0] * num_img_embs
# Process language tokens
def lang_embed_func(tokens):
# embed_language_tokens -> Gemma embed_tokens, which is GemmaTextScaledWordEmbedding
# (transformers >=5.4.0): it already multiplies by sqrt(hidden_size) internally. Do NOT
# scale again here or text tokens get double-scaled (~45x) and break alignment.
return self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
pad_masks.append(masks)
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def embed_suffix(self, noisy_actions, timestep):
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
# Embed timestep using sine-cosine positional encoding
time_emb = create_sinusoidal_pos_embedding(
timestep,
self.action_in_proj.out_features,
min_period=self.config.min_period,
max_period=self.config.max_period,
device=timestep.device,
)
time_emb = time_emb.type(dtype=timestep.dtype)
# Fuse timestep + action information using an MLP
def action_proj_func(noisy_actions):
return self.action_in_proj(noisy_actions)
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
def time_mlp_func(time_emb):
x = self.time_mlp_in(time_emb)
x = F.silu(x)
x = self.time_mlp_out(x)
return F.silu(x)
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
action_time_emb = action_emb
adarms_cond = time_emb
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] + ([0] * (self.config.chunk_size - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks, adarms_cond
def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor:
"""Do a full training forward pass and compute the loss."""
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
# Selective AC: rely on the per-layer checkpoint inside
# ``PaliGemmaWithExpertModel.forward`` (which wraps each
# transformer block individually). The previous outer
# ``_apply_checkpoint(forward_func, ...)`` doubled up — it
# re-ran the full backbone forward during backward *and* each
# block's own checkpoint re-ran during that recompute. Pure
# waste with SDPA, which already streams attention activations.
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
def action_out_proj_func(suffix_out):
return self.action_out_proj(suffix_out)
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(
self,
images,
img_masks,
tokens,
masks,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
num_steps = self.config.num_inference_steps
bsize = tokens.shape[0]
device = tokens.device
if noise is None:
# Sample noise with padded dimension as expected by action_in_proj
actions_shape = (
bsize,
self.config.chunk_size,
self.config.max_action_dim,
) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
prefix_att_2d_masks, dtype=prefix_embs.dtype
)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=True,
)
dt = -1.0 / num_steps
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
x_t = x_t + dt * v_t
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
return x_t
def denoise_step(
self,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
full_att_2d_masks_4d = self._prepare_attention_masks_4d(
full_att_2d_masks, dtype=suffix_embs.dtype
)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
return self.action_out_proj(suffix_out)
@@ -268,8 +268,6 @@ class PI0FastPaliGemma(nn.Module):
return features
def embed_language_tokens(self, tokens: torch.Tensor):
# get_input_embeddings() is GemmaTextScaledWordEmbedding (transformers >=5.4.0): it already
# multiplies by sqrt(hidden_size) internally, so no manual scaling is needed here.
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
def forward(
@@ -33,7 +33,7 @@ pytest.importorskip("transformers")
from transformers.models.gemma import modeling_gemma # noqa: E402
from lerobot.policies.pi05.modeling_pi05 import ( # noqa: E402
from lerobot.policies.pi052.pi05_backbone import ( # noqa: E402
make_att_2d_masks,
sdpa_attention_forward,
)