diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 1819abd22..a8e2103fa 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -170,6 +170,11 @@ def _compute_layer_ki( ): from transformers.models.gemma import modeling_gemma # noqa: PLC0415 + # ``_gated_residual`` is a lerobot helper (adaRMSNorm gated residual), + # not part of HF's ``modeling_gemma``. pi05's own layer code imports + # it from ``pi_gemma`` — mirror that here. + from ..pi_gemma import _gated_residual # noqa: PLC0415 + models = [paligemma.model.language_model, gemma_expert.model] query_states, key_states, value_states, gates = [], [], [], [] @@ -247,13 +252,13 @@ def _compute_layer_ki( if att.dtype != layer.self_attn.o_proj.weight.dtype: att = att.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att[:, start:end]) - out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + out_emb = _gated_residual(hidden_states, out_emb, gates[i]) after_first = out_emb.clone() out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i]) if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) - out_emb = modeling_gemma._gated_residual(after_first, out_emb, gate) # noqa: SLF001 + out_emb = _gated_residual(after_first, out_emb, gate) outputs_embeds.append(out_emb) start = end return outputs_embeds