mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
fix(pi052): knowledge insulation crashed on wrong _gated_residual import
`_compute_layer_ki` called `modeling_gemma._gated_residual`, but that adaRMSNorm gated-residual helper is a lerobot helper defined in `pi_gemma`, not part of HF transformers, so enabling `knowledge_insulation` crashed with an `AttributeError` on the first training step. Import `_gated_residual` from `pi_gemma` instead, matching pi05's own layer code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -170,6 +170,11 @@ def _compute_layer_ki(
|
|||||||
):
|
):
|
||||||
from transformers.models.gemma import modeling_gemma # noqa: PLC0415
|
from transformers.models.gemma import modeling_gemma # noqa: PLC0415
|
||||||
|
|
||||||
|
# ``_gated_residual`` is a lerobot helper (adaRMSNorm gated residual),
|
||||||
|
# not part of HF's ``modeling_gemma``. pi05's own layer code imports
|
||||||
|
# it from ``pi_gemma`` — mirror that here.
|
||||||
|
from ..pi_gemma import _gated_residual # noqa: PLC0415
|
||||||
|
|
||||||
models = [paligemma.model.language_model, gemma_expert.model]
|
models = [paligemma.model.language_model, gemma_expert.model]
|
||||||
query_states, key_states, value_states, gates = [], [], [], []
|
query_states, key_states, value_states, gates = [], [], [], []
|
||||||
|
|
||||||
@@ -247,13 +252,13 @@ def _compute_layer_ki(
|
|||||||
if att.dtype != layer.self_attn.o_proj.weight.dtype:
|
if att.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||||
att = att.to(layer.self_attn.o_proj.weight.dtype)
|
att = att.to(layer.self_attn.o_proj.weight.dtype)
|
||||||
out_emb = layer.self_attn.o_proj(att[:, start:end])
|
out_emb = layer.self_attn.o_proj(att[:, start:end])
|
||||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||||
after_first = out_emb.clone()
|
after_first = out_emb.clone()
|
||||||
out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
|
out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
|
||||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||||
out_emb = layer.mlp(out_emb)
|
out_emb = layer.mlp(out_emb)
|
||||||
out_emb = modeling_gemma._gated_residual(after_first, out_emb, gate) # noqa: SLF001
|
out_emb = _gated_residual(after_first, out_emb, gate)
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
start = end
|
start = end
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
|
|||||||
Reference in New Issue
Block a user