This commit is contained in:
Jade Choghari
2026-01-26 09:19:14 +00:00
parent 5e609426fd
commit 4c694e20c7
@@ -317,16 +317,16 @@ def compute_layer_complete(
V_action = value_states[:, :, vlm_len:, :] # (B, num_kv_heads, action_len, head_dim)
# create detached vlm K/V for action queries
# .detach() stops gradient flow: action loss won't backprop into VLM's K/V projections
# .detach() stops gradient flow: action loss wont backprop into VLM's K/V projections
K_vlm_detached = K_vlm.detach()
V_vlm_detached = V_vlm.detach()
# K/V for VLM queries: use original (full gradient flow for VLM self-attention)
K_for_vlm = key_states # Full concatenated K: [K_vlm, K_action]
V_for_vlm = value_states # Full concatenated V: [V_vlm, V_action]
K_for_vlm = key_states # [K_vlm, K_action]
V_for_vlm = value_states # [V_vlm, V_action]
# K/V for action queries: detached VLM K/V + normal action K/V
# This implements the knowledge insulation: action queries can "see" VLM K/V
# knowledge insulation: action queries can "see" VLM K/V
# in forward pass, but gradients are blocked in backward pass
K_for_action = torch.cat([K_vlm_detached, K_action], dim=2)
V_for_action = torch.cat([V_vlm_detached, V_action], dim=2)