mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
add KI optional
This commit is contained in:
@@ -88,6 +88,12 @@ class PI05FullConfig(PreTrainedConfig):
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = False # Freeze only the vision encoder
|
||||
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
|
||||
knowledge_insulation: bool = True # Enable knowledge insulation in attention (blocks gradients from action to VLM K/V)
|
||||
|
||||
# Loss weights (used when knowledge_insulation is enabled)
|
||||
loss_weight_flow: float = 1.0 # Weight for flow matching MSE loss (continuous actions)
|
||||
loss_weight_action_ce: float = 1.0 # Weight for FAST action token cross-entropy loss
|
||||
loss_weight_subtask_ce: float = 1.0 # Weight for subtask token cross-entropy loss
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
|
||||
@@ -222,9 +222,84 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
return padded_images
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
# Define the complete layer computation function for gradient checkpointing (without knowledge insulation)
|
||||
def compute_layer_complete(
    layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
    """Compute a single transformer layer with fused attention across VLM and action expert (no knowledge insulation).

    Each stream in ``inputs_embeds`` is normalised and projected to Q/K/V by its
    own model's layer (stream 0 -> ``paligemma.language_model``, stream 1 ->
    ``gemma_expert.model``). The projections are concatenated along dim 2 so a
    single attention call covers both streams, after which the result is split
    back per stream and passed through that stream's output projection, gated
    residuals and MLP.

    Args:
        layer_idx: Index of the decoder layer to run in both models.
        inputs_embeds: Sequence of per-stream hidden states, ordered as
            ``[vlm_stream, expert_stream]``. Presumably each is
            ``(batch, seq, hidden)`` -- TODO confirm against the caller.
        attention_mask: Mask forwarded unchanged to
            ``modeling_gemma.eager_attention_forward``.
        position_ids: Position ids forwarded to the rotary embedding module.
        adarms_cond: Per-stream conditioning passed as ``cond`` to the layer
            norms; indexed in the same order as ``inputs_embeds``.
        paligemma: VLM wrapper exposing ``language_model.layers`` and
            ``model.language_model.rotary_emb``.
        gemma_expert: Action-expert wrapper exposing ``model.layers``.

    Returns:
        A list with one output hidden-state tensor per entry of
        ``inputs_embeds``, in the same order.
    """
    models = [paligemma.language_model, gemma_expert.model]
    # Attention hyper-parameters (scaling, head_dim) are read from the VLM layer;
    # both streams share one attention call, so they must agree on head_dim.
    vlm_self_attn = paligemma.language_model.layers[layer_idx].self_attn

    query_parts = []
    key_parts = []
    value_parts = []
    gates = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        # The layer norm returns the normalised states plus a gate that is
        # consumed later by the first gated residual.
        normed_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
        gates.append(gate)

        # Split the projection's last dim into (heads, head_dim), then move the
        # head axis in front of the sequence axis.
        hidden_shape = (*normed_states.shape[:-1], -1, layer.self_attn.head_dim)
        query_parts.append(layer.self_attn.q_proj(normed_states).view(hidden_shape).transpose(1, 2))
        key_parts.append(layer.self_attn.k_proj(normed_states).view(hidden_shape).transpose(1, 2))
        value_parts.append(layer.self_attn.v_proj(normed_states).view(hidden_shape).transpose(1, 2))

    # Concatenate the streams (dim 2) so one attention call spans both of them.
    query_states = torch.cat(query_parts, dim=2)
    key_states = torch.cat(key_parts, dim=2)
    value_states = torch.cat(value_parts, dim=2)

    # The rotary module only needs a tensor carrying the right device/dtype;
    # a zero tensor stands in for the hidden states.
    dummy_tensor = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )

    batch_size = query_states.shape[0]
    # Number of query heads, taken from the tensor instead of a hard-coded 8 so
    # the reshape below stays correct for configs with a different head count.
    num_heads = query_states.shape[1]
    scaling = vlm_self_attn.scaling

    # Attention computation
    att_output, _ = modeling_gemma.eager_attention_forward(
        vlm_self_attn,
        query_states,
        key_states,
        value_states,
        attention_mask,
        scaling,
    )
    # Get head_dim from the current layer, not from the model
    head_dim = vlm_self_attn.head_dim
    # Merge the head axes back into a single feature axis per token.
    att_output = att_output.reshape(batch_size, -1, num_heads * head_dim)

    # Process layer outputs: slice the fused attention output back per stream.
    outputs_embeds = []
    start_pos = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        end_pos = start_pos + hidden_states.shape[1]

        # Match the dtype of this stream's output projection before applying it.
        o_proj_dtype = layer.self_attn.o_proj.weight.dtype
        if att_output.dtype != o_proj_dtype:
            att_output = att_output.to(o_proj_dtype)
        out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])

        # first residual (uses the un-normalised stream input and the gate
        # produced by input_layernorm)
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        after_first_residual = out_emb.clone()

        out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
        # Convert to bfloat16 if the next layer (mlp) uses bfloat16
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)

        # second residual
        out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
        outputs_embeds.append(out_emb)
        start_pos = end_pos
    return outputs_embeds
|
||||
|
||||
|
||||
# Define the complete layer computation function with knowledge insulation for gradient checkpointing
|
||||
def compute_layer_complete_knowledge_insulation(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
"""
|
||||
Compute a single transformer layer with fused attention across VLM and action expert.
|
||||
@@ -439,12 +514,14 @@ class PaliGemmaWithExpertModel(
|
||||
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||
freeze_vision_encoder: bool = False,
|
||||
train_expert_only: bool = False,
|
||||
knowledge_insulation: bool = True,
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
super().__init__()
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.knowledge_insulation = knowledge_insulation
|
||||
|
||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
||||
@@ -578,11 +655,16 @@ class PaliGemmaWithExpertModel(
|
||||
and self.training
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Select the appropriate layer computation function based on knowledge_insulation
|
||||
layer_compute_fn = (
|
||||
compute_layer_complete_knowledge_insulation if self.knowledge_insulation else compute_layer_complete
|
||||
)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_compute_fn,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
@@ -594,7 +676,7 @@ class PaliGemmaWithExpertModel(
|
||||
gemma_expert=self.gemma_expert,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
inputs_embeds = layer_compute_fn(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
@@ -655,6 +737,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
image_size=config.image_resolution[0],
|
||||
freeze_vision_encoder=config.freeze_vision_encoder,
|
||||
train_expert_only=config.train_expert_only,
|
||||
knowledge_insulation=config.knowledge_insulation,
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||
@@ -1113,11 +1196,22 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
||||
flow_loss = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
# Compute weighted total loss
|
||||
flow_loss_mean = flow_loss.mean()
|
||||
action_ce_loss_mean = fast_loss.mean()
|
||||
subtask_ce_loss_mean = subtask_loss.mean()
|
||||
|
||||
total_loss = (
|
||||
self.config.loss_weight_flow * flow_loss_mean
|
||||
+ self.config.loss_weight_action_ce * action_ce_loss_mean
|
||||
+ self.config.loss_weight_subtask_ce * subtask_ce_loss_mean
|
||||
)
|
||||
|
||||
return {
|
||||
"flow_mse_loss": flow_loss.mean(),
|
||||
"action_ce_loss": fast_loss.mean(),
|
||||
"subtask_ce_loss": subtask_loss,
|
||||
"loss": flow_loss.mean() + subtask_loss.mean() + fast_loss.mean(), # TODO: jadechoghari: check weights
|
||||
"flow_mse_loss": flow_loss_mean,
|
||||
"action_ce_loss": action_ce_loss_mean,
|
||||
"subtask_ce_loss": subtask_ce_loss_mean,
|
||||
"loss": total_loss,
|
||||
}
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
|
||||
Reference in New Issue
Block a user