mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
comments
This commit is contained in:
@@ -317,16 +317,16 @@ def compute_layer_complete(
|
||||
V_action = value_states[:, :, vlm_len:, :] # (B, num_kv_heads, action_len, head_dim)
|
||||
|
||||
# create detached vlm K/V for action queries
|
||||
# .detach() stops gradient flow: action loss won't backprop into VLM's K/V projections
|
||||
# .detach() stops gradient flow: action loss won't backprop into VLM's K/V projections
|
||||
K_vlm_detached = K_vlm.detach()
|
||||
V_vlm_detached = V_vlm.detach()
|
||||
|
||||
# K/V for VLM queries: use original (full gradient flow for VLM self-attention)
|
||||
K_for_vlm = key_states # Full concatenated K: [K_vlm, K_action]
|
||||
V_for_vlm = value_states # Full concatenated V: [V_vlm, V_action]
|
||||
K_for_vlm = key_states # [K_vlm, K_action]
|
||||
V_for_vlm = value_states # [V_vlm, V_action]
|
||||
|
||||
# K/V for action queries: detached VLM K/V + normal action K/V
|
||||
# This implements the knowledge insulation: action queries can "see" VLM K/V
|
||||
# knowledge insulation: action queries can "see" VLM K/V
|
||||
# in forward pass, but gradients are blocked in backward pass
|
||||
K_for_action = torch.cat([K_vlm_detached, K_action], dim=2)
|
||||
V_for_action = torch.cat([V_vlm_detached, V_action], dim=2)
|
||||
|
||||
Reference in New Issue
Block a user