mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
add knowledge insulation
This commit is contained in:
@@ -16,14 +16,13 @@ TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 5 seconds (all episodes processed)
|
||||
|
||||
# Run subtask annotation
|
||||
python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--repo-id "$REPO_ID" \
|
||||
--video-key observation.images.image \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--skip-existing \
|
||||
--output-repo-id "jadechoghari/libero10-annotate" \
|
||||
--batch-size "$BATCH_SIZE" \
|
||||
|
||||
# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --video-key observation.images.image \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --skip-existing \
|
||||
# --output-repo-id "jadechoghari/libero10-annotate" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# run synthetic data generation (all episodes processed)
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
@@ -42,10 +41,10 @@ python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
# --data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --video-mode \
|
||||
# --video-key observation.images.image \
|
||||
# --video-batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval 5.0
|
||||
python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--video-mode \
|
||||
--video-key observation.images.image \
|
||||
--video-batch-size "$BATCH_SIZE" \
|
||||
--sample-interval 5.0
|
||||
|
||||
@@ -51,6 +51,8 @@ from lerobot.utils.constants import (
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
OBS_LANGUAGE_USER_PROMPT_TOKENS,
|
||||
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
ACTION_TOKENS,
|
||||
ACTION_TOKEN_MASK,
|
||||
)
|
||||
@@ -223,11 +225,51 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
"""
|
||||
Compute a single transformer layer with fused attention across VLM and action expert.
|
||||
|
||||
This implements knowledge insulation (KI): the action expert's loss must not update the VLM backbone through the shared attention.
|
||||
|
||||
Forward pass:
|
||||
- VLM tokens (backbone) and action tokens (expert) are processed together
|
||||
- Action queries CAN attend to VLM keys/values (cross-attention)
|
||||
- VLM queries attend to all tokens normally
|
||||
|
||||
Backward pass (KI):
|
||||
- Gradients from action tokens MUST NOT flow into VLM keys/values
|
||||
- This prevents the action expert's loss (flow-matching/diffusion) from
|
||||
updating the VLM backbone parameters through the K/V projections
|
||||
- VLM self-attention gradients remain unchanged
|
||||
- Action self-attention gradients remain unchanged
|
||||
|
||||
Implementation:
|
||||
- Split attention into two parts: VLM queries and action queries
|
||||
- VLM queries use original (non-detached) K/V for full gradient flow
|
||||
- Action queries use detached VLM K/V (stops gradients) + normal action K/V
|
||||
- Results are concatenated to produce the same output as unified attention
|
||||
|
||||
Args:
|
||||
layer_idx: Index of the current transformer layer
|
||||
inputs_embeds: List of [vlm_embeds, action_embeds] tensors
|
||||
attention_mask: 4D attention mask (B, 1, total_len, total_len)
|
||||
position_ids: Position IDs for rotary embeddings
|
||||
adarms_cond: Conditioning for adaptive RMS norm [vlm_cond, action_cond]
|
||||
paligemma: The VLM (PaliGemma) model
|
||||
gemma_expert: The action expert (Gemma) model
|
||||
|
||||
Returns:
|
||||
outputs_embeds: List of [vlm_output, action_output] tensors
|
||||
"""
|
||||
models = [paligemma.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
|
||||
# this tracks vlm (backbone) token length for knowledge insulation
|
||||
# inputs_embeds[0] = vlm tokens, inputs_embeds[1] = action expert tokens
|
||||
vlm_len = inputs_embeds[0].shape[1]
|
||||
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
||||
@@ -240,10 +282,14 @@ def compute_layer_complete(
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
# Concatenate and process attention
|
||||
|
||||
# concat Q/K/V across VLM and action tokens
|
||||
# (B, num_heads, vlm_len + action_len, head_dim)
|
||||
query_states = torch.cat(query_states, dim=2)
|
||||
key_states = torch.cat(key_states, dim=2)
|
||||
value_states = torch.cat(value_states, dim=2)
|
||||
|
||||
# Apply rotary position embeddings
|
||||
dummy_tensor = torch.zeros(
|
||||
query_states.shape[0],
|
||||
query_states.shape[2],
|
||||
@@ -255,21 +301,70 @@ def compute_layer_complete(
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
|
||||
# KNOWLEDGE INSULATION
|
||||
# split queries into vlm (backbone) and action (expert) parts
|
||||
Q_vlm = query_states[:, :, :vlm_len, :] # (B, num_heads, vlm_len, head_dim)
|
||||
Q_action = query_states[:, :, vlm_len:, :] # (B, num_heads, action_len, head_dim)
|
||||
|
||||
# split K/V into vlm and action parts
|
||||
K_vlm = key_states[:, :, :vlm_len, :] # (B, num_kv_heads, vlm_len, head_dim)
|
||||
K_action = key_states[:, :, vlm_len:, :] # (B, num_kv_heads, action_len, head_dim)
|
||||
V_vlm = value_states[:, :, :vlm_len, :] # (B, num_kv_heads, vlm_len, head_dim)
|
||||
V_action = value_states[:, :, vlm_len:, :] # (B, num_kv_heads, action_len, head_dim)
|
||||
|
||||
# create detached vlm K/V for action queries
|
||||
# .detach() stops gradient flow: action loss won't backprop into VLM's K/V projections
|
||||
K_vlm_detached = K_vlm.detach()
|
||||
V_vlm_detached = V_vlm.detach()
|
||||
|
||||
# K/V for VLM queries: use original (full gradient flow for VLM self-attention)
|
||||
K_for_vlm = key_states # Full concatenated K: [K_vlm, K_action]
|
||||
V_for_vlm = value_states # Full concatenated V: [V_vlm, V_action]
|
||||
|
||||
# K/V for action queries: detached VLM K/V + normal action K/V
|
||||
# This implements the knowledge insulation: action queries can "see" VLM K/V
|
||||
# in forward pass, but gradients are blocked in backward pass
|
||||
K_for_action = torch.cat([K_vlm_detached, K_action], dim=2)
|
||||
V_for_action = torch.cat([V_vlm_detached, V_action], dim=2)
|
||||
|
||||
# split attention mask for vlm and action queries
|
||||
# attention_mask shape: (B, 1, total_len, total_len)
|
||||
mask_for_vlm = attention_mask[:, :, :vlm_len, :] # (B, 1, vlm_len, total_len)
|
||||
mask_for_action = attention_mask[:, :, vlm_len:, :] # (B, 1, action_len, total_len)
|
||||
|
||||
# compute attention for vlm queries (normal gradient flow)
|
||||
att_output_vlm, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
Q_vlm,
|
||||
K_for_vlm,
|
||||
V_for_vlm,
|
||||
mask_for_vlm,
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
|
||||
# compute attention for action queries (insulated from vlm K/V gradients)
|
||||
att_output_action, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
Q_action,
|
||||
K_for_action,
|
||||
V_for_action,
|
||||
mask_for_action,
|
||||
scaling,
|
||||
)
|
||||
|
||||
# concat attention outputs to match original unified attention output shape
|
||||
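# eager_attention_forward returns sequence-major output, so dim=1 is the token axis (upstream transformers returns (B, seq_len, num_heads, head_dim) - TODO confirm for this fork); heads are flattened by the reshape below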
|
||||
att_output = torch.cat([att_output_vlm, att_output_action], dim=1)
|
||||
|
||||
# get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
|
||||
# process layer outputs (MLP, residuals, etc.)
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
@@ -282,7 +377,7 @@ def compute_layer_complete(
|
||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
# convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
@@ -1657,8 +1752,8 @@ class PI05FullPolicy(PreTrainedPolicy):
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
|
||||
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
|
||||
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ATTENTION_MASK}"]
|
||||
action_tokens, action_masks = batch[f"{ACTION_TOKENS}"], batch[f"{ACTION_TOKEN_MASK}"]
|
||||
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
@@ -84,6 +84,7 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
full_prompts = []
|
||||
for i, user_prompt in enumerate(user_prompts):
|
||||
cleaned_text = user_prompt.strip().replace("_", " ").replace("\n", " ")
|
||||
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
|
||||
full_prompts.append(full_prompt)
|
||||
@@ -94,6 +95,7 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
full_commands = []
|
||||
for i, command in enumerate(commands):
|
||||
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
|
||||
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
|
||||
full_command = f"Subtask: {cleaned_text};\n"
|
||||
full_commands.append(full_command)
|
||||
|
||||
|
||||
@@ -34,6 +34,8 @@ from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_USER_PROMPT,
|
||||
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
|
||||
@@ -168,6 +170,32 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
    """
    Read the optional "subtask" entry from the transition's complementary data.

    Args:
        transition: The environment transition to inspect.

    Returns:
        A one-element list when the entry is a single string, the entry itself when it is
        already a list containing only strings, and None in every other case (no
        complementary data, no "subtask" entry, or an unsupported value type).
    """
    extras = transition.get(TransitionKey.COMPLEMENTARY_DATA)
    raw_subtask = None if extras is None else extras.get("subtask")

    # A bare string is wrapped so the tokenizer always receives a list of strings.
    if isinstance(raw_subtask, str):
        return [raw_subtask]

    # Lists are passed through untouched, but only when every element is a string.
    if isinstance(raw_subtask, list):
        for item in raw_subtask:
            if not isinstance(item, str):
                return None
        return raw_subtask

    # None, or any other type, is treated as "no subtask available".
    return None
|
||||
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""
|
||||
Tokenizes the task description and user_prompt (if available) and adds them to the observation dictionary.
|
||||
@@ -221,6 +249,22 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_USER_PROMPT_TOKENS] = tokenized_user_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = tokenized_user_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
if subtask is not None:
|
||||
tokenized_subtask = self._tokenize_text(subtask)
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_subtask = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_subtask.items()
|
||||
}
|
||||
|
||||
# Add tokenized subtask to the observation
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
@@ -330,6 +374,17 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for subtask tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_SUBTASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
|
||||
@@ -113,8 +113,6 @@ def update_policy(
|
||||
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
|
||||
else:
|
||||
loss, output_dict = policy.forward(batch)
|
||||
policy.select_action(batch)
|
||||
breakpoint()
|
||||
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
|
||||
@@ -29,7 +29,9 @@ OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
OBS_LANGUAGE_USER_PROMPT = OBS_STR + ".user_prompt"
|
||||
OBS_LANGUAGE_USER_PROMPT_TOKENS = OBS_LANGUAGE_USER_PROMPT + ".tokens"
|
||||
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_USER_PROMPT_TOKENS + ".attention_mask"
|
||||
|
||||
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_TOKENS + ".attention_mask"
|
||||
ACTION = "action"
|
||||
ACTION_PREFIX = ACTION + "."
|
||||
ACTION_TOKENS = ACTION + ".tokens"
|
||||
|
||||
Reference in New Issue
Block a user