diff --git a/src/lerobot/policies/pi05_full/annotate/annotate_libero.sh b/src/lerobot/policies/pi05_full/annotate/annotate_libero.sh index eabc02075..cd2d61bec 100644 --- a/src/lerobot/policies/pi05_full/annotate/annotate_libero.sh +++ b/src/lerobot/policies/pi05_full/annotate/annotate_libero.sh @@ -16,14 +16,13 @@ TEMPERATURE=0.9 SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed) # Run subtask annotation -python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \ - --repo-id "$REPO_ID" \ - --video-key observation.images.image \ - --output-dir "$OUTPUT_DIR" \ - --skip-existing \ - --output-repo-id "jadechoghari/libero10-annotate" \ - --batch-size "$BATCH_SIZE" \ - +# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \ +# --repo-id "$REPO_ID" \ +# --video-key observation.images.image \ +# --output-dir "$OUTPUT_DIR" \ +# --skip-existing \ +# --output-repo-id "jadechoghari/libero10-annotate" \ +# --batch-size "$BATCH_SIZE" \ # run synthetic data generation (all episodes processed) # python examples/dataset/annotate_pgen.py \ # --repo-id "$REPO_ID" \ @@ -42,10 +41,10 @@ python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate # add --push-to-hub flag # efficient batch processing: 4 episodes at once -# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \ -# --data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \ -# --output-dir "$OUTPUT_DIR" \ -# --video-mode \ -# --video-key observation.images.image \ -# --video-batch-size "$BATCH_SIZE" \ -# --sample-interval 5.0 +python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \ + --data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \ + --output-dir "$OUTPUT_DIR" \ + --video-mode \ + --video-key observation.images.image \ + --video-batch-size "$BATCH_SIZE" \ + --sample-interval 5.0 diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index c9a8429cc..3db999fdf 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -51,6 +51,8 @@ from lerobot.utils.constants import ( OPENPI_ATTENTION_MASK_VALUE, OBS_LANGUAGE_USER_PROMPT_TOKENS, OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, ACTION_TOKENS, ACTION_TOKEN_MASK, ) @@ -223,11 +225,51 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert ): + """ + Compute a single transformer layer with fused attention across VLM and action expert. + + This implements knowledge insulation + + Forward pass: + - VLM tokens (backbone) and action tokens (expert) are processed together + - Action queries CAN attend to VLM keys/values (cross-attention) + - VLM queries attend to all tokens normally + + Backward pass (KI): + - Gradients from action tokens MUST NOT flow into VLM keys/values + - This prevents the action expert's loss (flow-matching/diffusion) from + updating the VLM backbone parameters through the K/V projections + - VLM self-attention gradients remain unchanged + - Action self-attention gradients remain unchanged + + Implementation: + - Split attention into two parts: VLM queries and action queries + - VLM queries use original (non-detached) K/V for full gradient flow + - Action queries use detached VLM K/V (stops gradients) + normal action K/V + - Results are concatenated to produce the same output as unified attention + + Args: + layer_idx: Index of the current transformer layer + inputs_embeds: List of [vlm_embeds, action_embeds] tensors + attention_mask: 4D attention mask (B, 1, total_len, total_len) + position_ids: Position IDs for rotary embeddings + adarms_cond: Conditioning for adaptive RMS norm [vlm_cond, action_cond] + paligemma: The VLM (PaliGemma) model + gemma_expert: The action expert (Gemma) model + + Returns: + outputs_embeds: List of [vlm_output, action_output] tensors + """ models = [paligemma.language_model, gemma_expert.model] query_states = [] key_states = [] value_states = [] gates = [] + + # this tracks vlm (backbone) token length for knowledge insulation + # inputs_embeds[0] = vlm tokens, inputs_embeds[1] = action expert tokens + vlm_len = inputs_embeds[0].shape[1] + for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 @@ -240,10 +282,14 @@ def compute_layer_complete( query_states.append(query_state) key_states.append(key_state) value_states.append(value_state) - # Concatenate and process attention + + # concat Q/K/V across VLM and action tokens + # (B, num_heads, vlm_len + action_len, head_dim) query_states = torch.cat(query_states, dim=2) key_states = torch.cat(key_states, dim=2) value_states = torch.cat(value_states, dim=2) + + # Apply rotary position embeddings dummy_tensor = torch.zeros( query_states.shape[0], query_states.shape[2], @@ -255,21 +301,70 @@ def compute_layer_complete( query_states, key_states = modeling_gemma.apply_rotary_pos_emb( query_states, key_states, cos, sin, unsqueeze_dim=1 ) + batch_size = query_states.shape[0] scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling - # Attention computation - att_output, _ = modeling_gemma.eager_attention_forward( + + # KNOWLEDGE INSULATION + # split queries into vlm (backbone) and action (expert) parts + Q_vlm = query_states[:, :, :vlm_len, :] # (B, num_heads, vlm_len, head_dim) + Q_action = query_states[:, :, vlm_len:, :] # (B, num_heads, action_len, head_dim) + + # split K/V into vlm and action parts + K_vlm = key_states[:, :, :vlm_len, :] # (B, num_kv_heads, vlm_len, head_dim) + K_action = key_states[:, :, vlm_len:, :] # (B, num_kv_heads, action_len, head_dim) + V_vlm = value_states[:, :, :vlm_len, :] # (B, num_kv_heads, vlm_len, head_dim) + V_action = value_states[:, :, vlm_len:, :] # (B, num_kv_heads, action_len, head_dim) + + # create detached vlm K/V for action queries + # .detach() stops gradient flow: action loss won't backprop into VLM's K/V projections + K_vlm_detached = K_vlm.detach() + V_vlm_detached = V_vlm.detach() + + # K/V for VLM queries: use original (full gradient flow for VLM self-attention) + K_for_vlm = key_states # Full concatenated K: [K_vlm, K_action] + V_for_vlm = value_states # Full concatenated V: [V_vlm, V_action] + + # K/V for action queries: detached VLM K/V + normal action K/V + # This implements the knowledge insulation: action queries can "see" VLM K/V + # in forward pass, but gradients are blocked in backward pass + K_for_action = torch.cat([K_vlm_detached, K_action], dim=2) + V_for_action = torch.cat([V_vlm_detached, V_action], dim=2) + + # split attention mask for vlm and action queries + # attention_mask shape: (B, 1, total_len, total_len) + mask_for_vlm = attention_mask[:, :, :vlm_len, :] # (B, 1, vlm_len, total_len) + mask_for_action = attention_mask[:, :, vlm_len:, :] # (B, 1, action_len, total_len) + + # compute attention for vlm queries (normal gradient flow) + att_output_vlm, _ = modeling_gemma.eager_attention_forward( paligemma.language_model.layers[layer_idx].self_attn, - query_states, - key_states, - value_states, - attention_mask, + Q_vlm, + K_for_vlm, + V_for_vlm, + mask_for_vlm, scaling, ) - # Get head_dim from the current layer, not from the model + + # compute attention for action queries (insulated from vlm K/V gradients) + att_output_action, _ = modeling_gemma.eager_attention_forward( + paligemma.language_model.layers[layer_idx].self_attn, + Q_action, + K_for_action, + V_for_action, + mask_for_action, + scaling, + ) + + # concat attention outputs to match original unified attention output shape + # att_output shape after eager_attention_forward: (B, seq_len, num_heads * head_dim) + att_output = torch.cat([att_output_vlm, att_output_action], dim=1) + + # get head_dim from the current layer, not from the model head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) - # Process layer outputs + + # process layer outputs (MLP, residuals, etc.) outputs_embeds = [] start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): @@ -282,7 +377,7 @@ def compute_layer_complete( out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 after_first_residual = out_emb.clone() out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) - # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + # convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) @@ -1657,8 +1752,8 @@ class PI05FullPolicy(PreTrainedPolicy): # Prepare inputs images, img_masks = self._preprocess_images(batch) - high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"] - subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ATTENTION_MASK}"] action_tokens, action_masks = batch[f"{ACTION_TOKENS}"], batch[f"{ACTION_TOKEN_MASK}"] actions = self.prepare_action(batch) diff --git a/src/lerobot/policies/pi05_full/processor_pi05.py b/src/lerobot/policies/pi05_full/processor_pi05.py index 08571da48..80059e9c9 100644 --- a/src/lerobot/policies/pi05_full/processor_pi05.py +++ b/src/lerobot/policies/pi05_full/processor_pi05.py @@ -84,6 +84,7 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep): full_prompts = [] for i, user_prompt in enumerate(user_prompts): cleaned_text = user_prompt.strip().replace("_", " ").replace("\n", " ") + cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari) state_str = " ".join(map(str, discretized_states[i])) full_prompt = f"Task: {cleaned_text}, State: {state_str};\n" full_prompts.append(full_prompt) @@ -94,6 +95,7 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep): full_commands = [] for i, command in enumerate(commands): cleaned_text = command.strip().replace("_", " ").replace("\n", " ") + cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari) full_command = f"Subtask: {cleaned_text};\n" full_commands.append(full_command) diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index eedd9ec4c..c70773eee 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -34,6 +34,8 @@ from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_USER_PROMPT, OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK, @@ -168,6 +170,32 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None + def get_subtask(self, transition: EnvTransition) -> list[str] | None: + """ + Extracts the subtask from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of subtask strings, or None if the subtask key is not found or the value is None. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return None + + subtask = complementary_data.get("subtask") + if subtask is None: + return None + + # Standardize to a list of strings for the tokenizer + if isinstance(subtask, str): + return [subtask] + elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask): + return subtask + + return None + def observation(self, observation: RobotObservation) -> RobotObservation: """ Tokenizes the task description and user_prompt (if available) and adds them to the observation dictionary. @@ -221,6 +249,22 @@ class TokenizerProcessorStep(ObservationProcessorStep): new_observation[OBS_LANGUAGE_USER_PROMPT_TOKENS] = tokenized_user_prompt["input_ids"] new_observation[OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = tokenized_user_prompt["attention_mask"].to(dtype=torch.bool) + # Tokenize subtask if available + subtask = self.get_subtask(self.transition) + if subtask is not None: + tokenized_subtask = self._tokenize_text(subtask) + + # Move new tokenized tensors to the detected device + if target_device is not None: + tokenized_subtask = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_subtask.items() + } + + # Add tokenized subtask to the observation + new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"] + new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(dtype=torch.bool) + return new_observation def _detect_device(self, transition: EnvTransition) -> torch.device | None: @@ -330,6 +374,17 @@ class TokenizerProcessorStep(ObservationProcessorStep): type=FeatureType.LANGUAGE, shape=(self.max_length,) ) + # Add features for subtask tokens and attention mask if they don't already exist + if OBS_LANGUAGE_SUBTASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + + if OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + return features diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 483b62a67..51554388e 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -113,8 +113,6 @@ def update_policy( output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"] else: loss, output_dict = policy.forward(batch) - policy.select_action(batch) - breakpoint() # TODO(rcadene): policy.unnormalize_outputs(out_dict) diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index b52a1b80d..09ab250f3 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -29,7 +29,9 @@ OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" OBS_LANGUAGE_USER_PROMPT = OBS_STR + ".user_prompt" OBS_LANGUAGE_USER_PROMPT_TOKENS = OBS_LANGUAGE_USER_PROMPT + ".tokens" OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_USER_PROMPT_TOKENS + ".attention_mask" - +OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask" +OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens" +OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_TOKENS + ".attention_mask" ACTION = "action" ACTION_PREFIX = ACTION + "." ACTION_TOKENS = ACTION + ".tokens"