add knowledge insulation

This commit is contained in:
Jade Choghari
2026-01-26 09:14:39 +00:00
parent d0b6a66f34
commit 5e609426fd
6 changed files with 181 additions and 30 deletions
@@ -16,14 +16,13 @@ TEMPERATURE=0.9
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
# Run subtask annotation
python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
--repo-id "$REPO_ID" \
--video-key observation.images.image \
--output-dir "$OUTPUT_DIR" \
--skip-existing \
--output-repo-id "jadechoghari/libero10-annotate" \
--batch-size "$BATCH_SIZE" \
# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
# --repo-id "$REPO_ID" \
# --video-key observation.images.image \
# --output-dir "$OUTPUT_DIR" \
# --skip-existing \
# --output-repo-id "jadechoghari/libero10-annotate" \
# --batch-size "$BATCH_SIZE" \
# run synthetic data generation (all episodes processed)
# python examples/dataset/annotate_pgen.py \
# --repo-id "$REPO_ID" \
@@ -42,10 +41,10 @@ python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate
# add --push-to-hub flag
# efficient batch processing: 4 episodes at once
# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
# --data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \
# --output-dir "$OUTPUT_DIR" \
# --video-mode \
# --video-key observation.images.image \
# --video-batch-size "$BATCH_SIZE" \
# --sample-interval 5.0
python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
--data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \
--output-dir "$OUTPUT_DIR" \
--video-mode \
--video-key observation.images.image \
--video-batch-size "$BATCH_SIZE" \
--sample-interval 5.0
+107 -12
View File
@@ -51,6 +51,8 @@ from lerobot.utils.constants import (
OPENPI_ATTENTION_MASK_VALUE,
OBS_LANGUAGE_USER_PROMPT_TOKENS,
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
ACTION_TOKENS,
ACTION_TOKEN_MASK,
)
@@ -223,11 +225,51 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
"""
Compute a single transformer layer with fused attention across VLM and action expert.
This implements knowledge insulation
Forward pass:
- VLM tokens (backbone) and action tokens (expert) are processed together
- Action queries CAN attend to VLM keys/values (cross-attention)
- VLM queries attend to all tokens normally
Backward pass (KI):
- Gradients from action tokens MUST NOT flow into VLM keys/values
- This prevents the action expert's loss (flow-matching/diffusion) from
updating the VLM backbone parameters through the K/V projections
- VLM self-attention gradients remain unchanged
- Action self-attention gradients remain unchanged
Implementation:
- Split attention into two parts: VLM queries and action queries
- VLM queries use original (non-detached) K/V for full gradient flow
- Action queries use detached VLM K/V (stops gradients) + normal action K/V
- Results are concatenated to produce the same output as unified attention
Args:
layer_idx: Index of the current transformer layer
inputs_embeds: List of [vlm_embeds, action_embeds] tensors
attention_mask: 4D attention mask (B, 1, total_len, total_len)
position_ids: Position IDs for rotary embeddings
adarms_cond: Conditioning for adaptive RMS norm [vlm_cond, action_cond]
paligemma: The VLM (PaliGemma) model
gemma_expert: The action expert (Gemma) model
Returns:
outputs_embeds: List of [vlm_output, action_output] tensors
"""
models = [paligemma.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
# this tracks vlm (backbone) token length for knowledge insulation
# inputs_embeds[0] = vlm tokens, inputs_embeds[1] = action expert tokens
vlm_len = inputs_embeds[0].shape[1]
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
@@ -240,10 +282,14 @@ def compute_layer_complete(
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# Concatenate and process attention
# concat Q/K/V across VLM and action tokens
# (B, num_heads, vlm_len + action_len, head_dim)
query_states = torch.cat(query_states, dim=2)
key_states = torch.cat(key_states, dim=2)
value_states = torch.cat(value_states, dim=2)
# Apply rotary position embeddings
dummy_tensor = torch.zeros(
query_states.shape[0],
query_states.shape[2],
@@ -255,21 +301,70 @@ def compute_layer_complete(
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
# KNOWLEDGE INSULATION
# split queries into vlm (backbone) and action (expert) parts
Q_vlm = query_states[:, :, :vlm_len, :] # (B, num_heads, vlm_len, head_dim)
Q_action = query_states[:, :, vlm_len:, :] # (B, num_heads, action_len, head_dim)
# split K/V into vlm and action parts
K_vlm = key_states[:, :, :vlm_len, :] # (B, num_kv_heads, vlm_len, head_dim)
K_action = key_states[:, :, vlm_len:, :] # (B, num_kv_heads, action_len, head_dim)
V_vlm = value_states[:, :, :vlm_len, :] # (B, num_kv_heads, vlm_len, head_dim)
V_action = value_states[:, :, vlm_len:, :] # (B, num_kv_heads, action_len, head_dim)
# create detached vlm K/V for action queries
# .detach() stops gradient flow: action loss won't backprop into VLM's K/V projections
K_vlm_detached = K_vlm.detach()
V_vlm_detached = V_vlm.detach()
# K/V for VLM queries: use original (full gradient flow for VLM self-attention)
K_for_vlm = key_states # Full concatenated K: [K_vlm, K_action]
V_for_vlm = value_states # Full concatenated V: [V_vlm, V_action]
# K/V for action queries: detached VLM K/V + normal action K/V
# This implements the knowledge insulation: action queries can "see" VLM K/V
# in forward pass, but gradients are blocked in backward pass
K_for_action = torch.cat([K_vlm_detached, K_action], dim=2)
V_for_action = torch.cat([V_vlm_detached, V_action], dim=2)
# split attention mask for vlm and action queries
# attention_mask shape: (B, 1, total_len, total_len)
mask_for_vlm = attention_mask[:, :, :vlm_len, :] # (B, 1, vlm_len, total_len)
mask_for_action = attention_mask[:, :, vlm_len:, :] # (B, 1, action_len, total_len)
# compute attention for vlm queries (normal gradient flow)
att_output_vlm, _ = modeling_gemma.eager_attention_forward(
paligemma.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
attention_mask,
Q_vlm,
K_for_vlm,
V_for_vlm,
mask_for_vlm,
scaling,
)
# Get head_dim from the current layer, not from the model
# compute attention for action queries (insulated from vlm K/V gradients)
att_output_action, _ = modeling_gemma.eager_attention_forward(
paligemma.language_model.layers[layer_idx].self_attn,
Q_action,
K_for_action,
V_for_action,
mask_for_action,
scaling,
)
# concat attention outputs to match original unified attention output shape
# att_output shape after eager_attention_forward: (B, seq_len, num_heads * head_dim)
att_output = torch.cat([att_output_vlm, att_output_action], dim=1)
# get head_dim from the current layer, not from the model
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
# process layer outputs (MLP, residuals, etc.)
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
@@ -282,7 +377,7 @@ def compute_layer_complete(
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
after_first_residual = out_emb.clone()
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
# convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
@@ -1657,8 +1752,8 @@ class PI05FullPolicy(PreTrainedPolicy):
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ATTENTION_MASK}"]
action_tokens, action_masks = batch[f"{ACTION_TOKENS}"], batch[f"{ACTION_TOKEN_MASK}"]
actions = self.prepare_action(batch)
@@ -84,6 +84,7 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
full_prompts = []
for i, user_prompt in enumerate(user_prompts):
cleaned_text = user_prompt.strip().replace("_", " ").replace("\n", " ")
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
full_prompts.append(full_prompt)
@@ -94,6 +95,7 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
full_commands = []
for i, command in enumerate(commands):
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
full_command = f"Subtask: {cleaned_text};\n"
full_commands.append(full_command)
@@ -34,6 +34,8 @@ from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_LANGUAGE_TOKENS,
OBS_LANGUAGE_USER_PROMPT,
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
@@ -168,6 +170,32 @@ class TokenizerProcessorStep(ObservationProcessorStep):
return None
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
"""
Extracts the subtask from the transition's complementary data.
Args:
transition: The environment transition.
Returns:
A list of subtask strings, or None if the subtask key is not found or the value is None.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
subtask = complementary_data.get("subtask")
if subtask is None:
return None
# Standardize to a list of strings for the tokenizer
if isinstance(subtask, str):
return [subtask]
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
return subtask
return None
def observation(self, observation: RobotObservation) -> RobotObservation:
"""
Tokenizes the task description and user_prompt (if available) and adds them to the observation dictionary.
@@ -221,6 +249,22 @@ class TokenizerProcessorStep(ObservationProcessorStep):
new_observation[OBS_LANGUAGE_USER_PROMPT_TOKENS] = tokenized_user_prompt["input_ids"]
new_observation[OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = tokenized_user_prompt["attention_mask"].to(dtype=torch.bool)
# Tokenize subtask if available
subtask = self.get_subtask(self.transition)
if subtask is not None:
tokenized_subtask = self._tokenize_text(subtask)
# Move new tokenized tensors to the detected device
if target_device is not None:
tokenized_subtask = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_subtask.items()
}
# Add tokenized subtask to the observation
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(dtype=torch.bool)
return new_observation
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
@@ -330,6 +374,17 @@ class TokenizerProcessorStep(ObservationProcessorStep):
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
# Add features for subtask tokens and attention mask if they don't already exist
if OBS_LANGUAGE_SUBTASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
if OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
return features
-2
View File
@@ -113,8 +113,6 @@ def update_policy(
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
else:
loss, output_dict = policy.forward(batch)
policy.select_action(batch)
breakpoint()
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
+3 -1
View File
@@ -29,7 +29,9 @@ OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
OBS_LANGUAGE_USER_PROMPT = OBS_STR + ".user_prompt"
OBS_LANGUAGE_USER_PROMPT_TOKENS = OBS_LANGUAGE_USER_PROMPT + ".tokens"
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_USER_PROMPT_TOKENS + ".attention_mask"
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_TOKENS + ".attention_mask"
ACTION = "action"
ACTION_PREFIX = ACTION + "."
ACTION_TOKENS = ACTION + ".tokens"