add KI optional

This commit is contained in:
Jade Choghari
2026-02-02 15:58:47 +00:00
parent 092f4617ca
commit 6c94fcd1b1
2 changed files with 107 additions and 7 deletions
@@ -88,6 +88,12 @@ class PI05FullConfig(PreTrainedConfig):
# Finetuning settings
freeze_vision_encoder: bool = False # Freeze only the vision encoder
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
knowledge_insulation: bool = True # Enable knowledge insulation in attention (blocks gradients from action to VLM K/V)
# Loss weights (used when knowledge_insulation is enabled)
loss_weight_flow: float = 1.0 # Weight for flow matching MSE loss (continuous actions)
loss_weight_action_ce: float = 1.0 # Weight for FAST action token cross-entropy loss
loss_weight_subtask_ce: float = 1.0 # Weight for subtask token cross-entropy loss
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
+101 -7
View File
@@ -222,9 +222,84 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
return padded_images
# Define the complete layer computation function for gradient checkpointing
# Define the complete layer computation function for gradient checkpointing (without knowledge insulation)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
"""Compute a single transformer layer with fused attention across VLM and action expert (no knowledge insulation)."""
models = [paligemma.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# Concatenate and process attention
query_states = torch.cat(query_states, dim=2)
key_states = torch.cat(key_states, dim=2)
value_states = torch.cat(value_states, dim=2)
dummy_tensor = torch.zeros(
query_states.shape[0],
query_states.shape[2],
query_states.shape[-1],
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
attention_mask,
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
after_first_residual = out_emb.clone()
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
# Define the complete layer computation function with knowledge insulation for gradient checkpointing
def compute_layer_complete_knowledge_insulation(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
"""
Compute a single transformer layer with fused attention across VLM and action expert.
@@ -439,12 +514,14 @@ class PaliGemmaWithExpertModel(
image_size: int = DEFAULT_IMAGE_SIZE,
freeze_vision_encoder: bool = False,
train_expert_only: bool = False,
knowledge_insulation: bool = True,
):
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.knowledge_insulation = knowledge_insulation
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
@@ -578,11 +655,16 @@ class PaliGemmaWithExpertModel(
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Select the appropriate layer computation function based on knowledge_insulation
layer_compute_fn = (
compute_layer_complete_knowledge_insulation if self.knowledge_insulation else compute_layer_complete
)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_compute_fn,
layer_idx,
inputs_embeds,
attention_mask,
@@ -594,7 +676,7 @@ class PaliGemmaWithExpertModel(
gemma_expert=self.gemma_expert,
)
else:
inputs_embeds = compute_layer_complete(
inputs_embeds = layer_compute_fn(
layer_idx,
inputs_embeds,
attention_mask,
@@ -655,6 +737,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
image_size=config.image_resolution[0],
freeze_vision_encoder=config.freeze_vision_encoder,
train_expert_only=config.train_expert_only,
knowledge_insulation=config.knowledge_insulation,
)
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -1113,11 +1196,22 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
flow_loss = F.mse_loss(u_t, v_t, reduction="none")
# Compute weighted total loss
flow_loss_mean = flow_loss.mean()
action_ce_loss_mean = fast_loss.mean()
subtask_ce_loss_mean = subtask_loss.mean()
total_loss = (
self.config.loss_weight_flow * flow_loss_mean
+ self.config.loss_weight_action_ce * action_ce_loss_mean
+ self.config.loss_weight_subtask_ce * subtask_ce_loss_mean
)
return {
"flow_mse_loss": flow_loss.mean(),
"action_ce_loss": fast_loss.mean(),
"subtask_ce_loss": subtask_loss,
"loss": flow_loss.mean() + subtask_loss.mean() + fast_loss.mean(), # TODO: jadechoghari: check weights
"flow_mse_loss": flow_loss_mean,
"action_ce_loss": action_ce_loss_mean,
"subtask_ce_loss": subtask_ce_loss_mean,
"loss": total_loss,
}
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)