From ddf4bc2063c4888809cbe1cfd9e2773d42fb1d45 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 19 May 2026 22:48:02 +0200 Subject: [PATCH] fix(pi052): knowledge insulation crashed on wrong _gated_residual import MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _compute_layer_ki called modeling_gemma._gated_residual, but that adaRMSNorm gated-residual helper is a lerobot helper in pi_gemma, not part of HF transformers — so enabling knowledge_insulation crashed with AttributeError on the first training step. Import _gated_residual from pi_gemma, matching pi05's own layer code. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/policies/pi052/modeling_pi052.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 1819abd22..a8e2103fa 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -170,6 +170,11 @@ def _compute_layer_ki( ): from transformers.models.gemma import modeling_gemma # noqa: PLC0415 + # ``_gated_residual`` is a lerobot helper (adaRMSNorm gated residual), + # not part of HF's ``modeling_gemma``. pi05's own layer code imports + # it from ``pi_gemma`` — mirror that here. + from ..pi_gemma import _gated_residual # noqa: PLC0415 + models = [paligemma.model.language_model, gemma_expert.model] query_states, key_states, value_states, gates = [], [], [], [] @@ -247,13 +252,13 @@ def _compute_layer_ki( if att.dtype != layer.self_attn.o_proj.weight.dtype: att = att.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att[:, start:end]) - out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + out_emb = _gated_residual(hidden_states, out_emb, gates[i]) after_first = out_emb.clone() out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i]) if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) - out_emb = modeling_gemma._gated_residual(after_first, out_emb, gate) # noqa: SLF001 + out_emb = _gated_residual(after_first, out_emb, gate) outputs_embeds.append(out_emb) start = end return outputs_embeds