From 6c94fcd1b17c9c37556fcd4e387bbd8f6db5ba0e Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 2 Feb 2026 15:58:47 +0000 Subject: [PATCH] add KI optional --- .../policies/pi05_full/configuration_pi05.py | 6 + .../policies/pi05_full/modeling_pi05.py | 108 ++++++++++++++++-- 2 files changed, 107 insertions(+), 7 deletions(-) diff --git a/src/lerobot/policies/pi05_full/configuration_pi05.py b/src/lerobot/policies/pi05_full/configuration_pi05.py index a95645220..744854521 100644 --- a/src/lerobot/policies/pi05_full/configuration_pi05.py +++ b/src/lerobot/policies/pi05_full/configuration_pi05.py @@ -88,6 +88,12 @@ class PI05FullConfig(PreTrainedConfig): # Finetuning settings freeze_vision_encoder: bool = False # Freeze only the vision encoder train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections + knowledge_insulation: bool = True # Enable knowledge insulation in attention (blocks gradients from action to VLM K/V) + + # Loss weights (used when knowledge_insulation is enabled) + loss_weight_flow: float = 1.0 # Weight for flow matching MSE loss (continuous actions) + loss_weight_action_ce: float = 1.0 # Weight for FAST action token cross-entropy loss + loss_weight_subtask_ce: float = 1.0 # Weight for subtask token cross-entropy loss # Optimizer settings: see openpi `AdamW` optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index a13767794..ca66becb7 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -222,9 +222,84 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) return padded_images -# Define the complete layer computation function for gradient checkpointing +# Define the complete layer computation function for gradient checkpointing (without knowledge insulation) def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert +): + """Compute a single transformer layer with fused attention across VLM and action expert (no knowledge insulation).""" + models = [paligemma.language_model, gemma_expert.model] + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + gates.append(gate) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + batch_size = query_states.shape[0] + scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + paligemma.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + # Get head_dim from the current layer, not from the model + head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + outputs_embeds.append(out_emb) + start_pos = end_pos + return outputs_embeds + + +# Define the complete layer computation function with knowledge insulation for gradient checkpointing +def compute_layer_complete_knowledge_insulation( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert ): """ Compute a single transformer layer with fused attention across VLM and action expert. @@ -439,12 +514,14 @@ class PaliGemmaWithExpertModel( image_size: int = DEFAULT_IMAGE_SIZE, freeze_vision_encoder: bool = False, train_expert_only: bool = False, + knowledge_insulation: bool = True, ): if use_adarms is None: use_adarms = [False, False] super().__init__() self.freeze_vision_encoder = freeze_vision_encoder self.train_expert_only = train_expert_only + self.knowledge_insulation = knowledge_insulation vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf._vocab_size = 257152 # noqa: SLF001 @@ -578,11 +655,16 @@ class PaliGemmaWithExpertModel( and self.training ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) + # Select the appropriate layer computation function based on knowledge_insulation + layer_compute_fn = ( + compute_layer_complete_knowledge_insulation if self.knowledge_insulation else compute_layer_complete + ) + # Process all layers with gradient checkpointing if enabled for layer_idx in range(num_layers): if use_gradient_checkpointing: inputs_embeds = torch.utils.checkpoint.checkpoint( - compute_layer_complete, + layer_compute_fn, layer_idx, inputs_embeds, attention_mask, @@ -594,7 +676,7 @@ class PaliGemmaWithExpertModel( gemma_expert=self.gemma_expert, ) else: - inputs_embeds = compute_layer_complete( + inputs_embeds = layer_compute_fn( layer_idx, inputs_embeds, attention_mask, @@ -655,6 +737,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` image_size=config.image_resolution[0], freeze_vision_encoder=config.freeze_vision_encoder, train_expert_only=config.train_expert_only, + knowledge_insulation=config.knowledge_insulation, ) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) @@ -1113,11 +1196,22 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) flow_loss = F.mse_loss(u_t, v_t, reduction="none") + # Compute weighted total loss + flow_loss_mean = flow_loss.mean() + action_ce_loss_mean = fast_loss.mean() + subtask_ce_loss_mean = subtask_loss.mean() + + total_loss = ( + self.config.loss_weight_flow * flow_loss_mean + + self.config.loss_weight_action_ce * action_ce_loss_mean + + self.config.loss_weight_subtask_ce * subtask_ce_loss_mean + ) + return { - "flow_mse_loss": flow_loss.mean(), - "action_ce_loss": fast_loss.mean(), - "subtask_ce_loss": subtask_loss, - "loss": flow_loss.mean() + subtask_loss.mean() + fast_loss.mean(), # TODO: jadechoghari: check weights + "flow_mse_loss": flow_loss_mean, + "action_ce_loss": action_ce_loss_mean, + "subtask_ce_loss": subtask_ce_loss_mean, + "loss": total_loss, } @torch.no_grad() # see openpi `sample_actions` (slightly adapted)