fix(pi052): knowledge insulation crashed on wrong _gated_residual import

_compute_layer_ki called modeling_gemma._gated_residual, but that
adaRMSNorm gated-residual helper is a lerobot helper in pi_gemma, not
part of HF transformers — so enabling knowledge_insulation crashed with
AttributeError on the first training step. Import _gated_residual from
pi_gemma, matching pi05's own layer code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-19 22:48:02 +02:00
parent b7317b6c29
commit ddf4bc2063
+7 -2
View File
@@ -170,6 +170,11 @@ def _compute_layer_ki(
):
from transformers.models.gemma import modeling_gemma # noqa: PLC0415
# ``_gated_residual`` is a lerobot helper (adaRMSNorm gated residual),
# not part of HF's ``modeling_gemma``. pi05's own layer code imports
# it from ``pi_gemma`` — mirror that here.
from ..pi_gemma import _gated_residual # noqa: PLC0415
models = [paligemma.model.language_model, gemma_expert.model]
query_states, key_states, value_states, gates = [], [], [], []
@@ -247,13 +252,13 @@ def _compute_layer_ki(
if att.dtype != layer.self_attn.o_proj.weight.dtype:
att = att.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att[:, start:end])
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first = out_emb.clone()
out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
out_emb = modeling_gemma._gated_residual(after_first, out_emb, gate) # noqa: SLF001
out_emb = _gated_residual(after_first, out_emb, gate)
outputs_embeds.append(out_emb)
start = end
return outputs_embeds