diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index 3db999fdf..8a00215f0 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -317,16 +317,16 @@ def compute_layer_complete( V_action = value_states[:, :, vlm_len:, :] # (B, num_kv_heads, action_len, head_dim) # create detached vlm K/V for action queries - # .detach() stops gradient flow: action loss won't backprop into VLM's K/V projections + # .detach() stops gradient flow: action loss won't backprop into VLM's K/V projections K_vlm_detached = K_vlm.detach() V_vlm_detached = V_vlm.detach() # K/V for VLM queries: use original (full gradient flow for VLM self-attention) - K_for_vlm = key_states # Full concatenated K: [K_vlm, K_action] - V_for_vlm = value_states # Full concatenated V: [V_vlm, V_action] + K_for_vlm = key_states # [K_vlm, K_action] + V_for_vlm = value_states # [V_vlm, V_action] # K/V for action queries: detached VLM K/V + normal action K/V - # This implements the knowledge insulation: action queries can "see" VLM K/V + # knowledge insulation: action queries can "see" VLM K/V # in forward pass, but gradients are blocked in backward pass K_for_action = torch.cat([K_vlm_detached, K_action], dim=2) V_for_action = torch.cat([V_vlm_detached, V_action], dim=2)