revert .clone

This commit is contained in:
Jade Choghari
2026-01-27 16:00:40 +00:00
parent 2bf6359d24
commit 6a6912ec37
@@ -375,8 +375,7 @@ def compute_layer_complete(
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
# Store reference instead of clone - we need original for second residual
after_first_residual = out_emb
after_first_residual = out_emb.clone()
out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
# convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16: