From 72b0af4ed7f0cea37990324453b11b5a861700a1 Mon Sep 17 00:00:00 2001
From: Jade Choghari
Date: Wed, 14 Jan 2026 14:52:32 +0000
Subject: [PATCH] add three losses: flow_mse, subtask_ce, action_ce

PI05Pytorch.forward now returns a dict instead of the raw flow-matching
MSE tensor:

  flow_loss        mean of F.mse_loss(u_t, v_t, reduction="none")
  subtask_ce_loss  masked next-token cross-entropy over the subtask tokens,
                   computed with the PaliGemma lm_head on prefix_out
  action_ce_loss   masked next-token cross-entropy over the FAST action
                   tokens, computed the same way
  loss             flow_loss + 0.1 * subtask_ce_loss + 0.05 * action_ce_loss
                   (weights still marked TODO in the code)

embed_prefix additionally returns image_len. forward uses it to build a
combined prefix/suffix attention mask in which the suffix (action expert)
attends to image, language and subtask tokens but not to the FAST tokens.

sample_actions is updated for the new embed_prefix return value and uses
the 2D prefix mask it returns directly.

The PI05FullPolicy loss path forwards the dict and reports the individual
terms for logging.
---
Reviewer notes (ignored by `git am`):

* This patch was reconstructed from a copy whose line breaks and
  indentation had been collapsed. Added/removed lines are verbatim apart
  from the items listed below; indentation, inline-comment spacing and the
  position of blank context lines are best-effort. Every hunk matches its
  @@ header counts, but verify the context against the target file or
  apply with `git apply --ignore-whitespace`. The `index` blob ids are
  carried over unchanged and no longer describe the post-image.
* Fix: sample_actions unpacked five values from embed_prefix, which this
  patch makes return four (embs, pad_masks, att_masks, image_len). It now
  unpacks four.
* Comment-only touch-ups on added lines (shape comments use the renamed
  fast_len / subtask_len; shift comments describe what is dropped).
* To confirm before merging: the last hunk removes the `reduction` branch,
  the `loss_per_dim` entry and the truncation to `original_action_dim`, so
  the flow loss is now averaged over every action dimension the model
  emits. Check that no caller relies on the previous behaviour.

 .../policies/pi05_full/modeling_pi05.py       | 169 +++++++++++++-----
 1 file changed, 129 insertions(+), 40 deletions(-)

diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py
index 7216c63b3..0348e6550 100644
--- a/src/lerobot/policies/pi05_full/modeling_pi05.py
+++ b/src/lerobot/policies/pi05_full/modeling_pi05.py
@@ -730,9 +730,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
         embs = []
         pad_masks = []
         att_mask_segments = []  # Store info about each segment for custom mask creation
-        total_T_images = 0
-        num_subtask_embs = 0
-        num_fast_embs = 0
+        image_len = 0
+        language_len = 0
+        subtask_len = 0
+        fast_len = 0
 
         # Process images
         for img, img_mask in zip(images, img_masks, strict=True):
@@ -746,7 +747,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
             embs.append(img_emb)
             pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
             att_mask_segments.append(('image', num_img_embs))
-            total_T_images += num_img_embs
+            image_len += num_img_embs
 
         # Process language instruction tokens
         def lang_embed_func(tokens):
@@ -774,8 +775,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
 
             # create subtask pad masks (non-zero tokens are valid)
             pad_masks.append(subtask_masks)
-            num_subtask_embs = subtask_emb.shape[1]
-            att_mask_segments.append(('subtask', num_subtask_embs))
+            subtask_len = subtask_emb.shape[1]
+            att_mask_segments.append(('subtask', subtask_len))
 
         # Process FAST action tokens (discrete token IDs)
         if fast_action_tokens is not None:
@@ -788,9 +789,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
 
             fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
             embs.append(fast_action_emb)
-            num_fast_embs = fast_action_tokens.shape[1]
+            fast_len = fast_action_tokens.shape[1]
             pad_masks.append(fast_action_masks)
-            att_mask_segments.append(("fast", num_fast_embs))
+            att_mask_segments.append(("fast", fast_len))
 
         embs = torch.cat(embs, dim=1)
         pad_masks = torch.cat(pad_masks, dim=1)
@@ -811,7 +812,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
         #     max_display_tokens=512  # Limit display for very long sequences
         # )
 
-        return embs, pad_masks, att_masks
+        return embs, pad_masks, att_masks, image_len
 
     def embed_suffix(self, noisy_actions, timestep):
         """Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
@@ -872,29 +873,62 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
         x_t = time_expanded * noise + (1 - time_expanded) * actions
         u_t = noise - actions
 
-        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks)
+        prefix_embs, prefix_pad_masks, prefix_att_masks, image_len = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks)
         suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
 
         #TODO jadechoghari
         # this attention part should be reworked
-        breakpoint()
+        suffix_att_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
+        # prefix_att_masks: shape [B, prefix_len, prefix_len]
+        # suffix_att_masks: shape [B, suffix_len, suffix_len]
+        # get: image_len, language_len, subtask_len, fast_len
+        language_len = high_level_task_tokens.shape[1]
+        subtask_len = subtask_tokens.shape[1]
+        fast_len = action_tokens.shape[1]
+
+        # create the correct attention mask for combining the prefix and suffix
+        # suffix embeddings from the action expert embeddings attend
+        # to the prefix and to one another, but do not attend to FAST action tokens
+        # final shape: [B, prefix_len+suffix_len, prefix_len+suffix_len]
+        # TODO: jadechoghari: put it in a function, and be more efficient
+        bsize = prefix_embs.shape[0]
+        prefix_len = prefix_pad_masks.shape[1]
+        suffix_len = suffix_pad_masks.shape[1]
+        total_len = prefix_len + suffix_len
+
+        # start with 0 mask
+        combined_att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=prefix_embs.device)
+
+        # top-left: prefix attends to prefix (using the custom prefix attention mask)
+        combined_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_masks
+
+        # bottom-right: suffix attends to suffix (using suffix attention mask)
+        combined_att_2d_masks[:, prefix_len:, prefix_len:] = suffix_att_masks
+
+        # bottom-left: suffix attends to prefix EXCEPT FAST tokens
+        # calculate where FAST tokens are in the prefix
+        prefix_without_fast_len = image_len + language_len + subtask_len
+
+        # suffix can attend to images, language, and subtask tokens (but not FAST)
+        combined_att_2d_masks[:, prefix_len:, :prefix_without_fast_len] = True
+
+        # apply padding masks to the combined attention mask
+        combined_pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
+        pad_2d_masks = combined_pad_masks[:, None, :] * combined_pad_masks[:, :, None]
+        att_2d_masks = combined_att_2d_masks & pad_2d_masks
+
         if (
             self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
             == torch.bfloat16
         ):
             suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
             prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
-
-        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
-        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
-
-        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
-        position_ids = torch.cumsum(pad_masks, dim=1) - 1
+        position_ids = torch.cumsum(combined_pad_masks, dim=1) - 1
 
         att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
 
         def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
-            (_, suffix_out), _ = self.paligemma_with_expert.forward(
+            (prefix_out, suffix_out), _ = self.paligemma_with_expert.forward(
                 attention_mask=att_2d_masks_4d,
                 position_ids=position_ids,
                 past_key_values=None,
@@ -902,15 +936,69 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
                 use_cache=False,
                 adarms_cond=[None, adarms_cond],
             )
-            return suffix_out
+            return prefix_out, suffix_out
 
         # TODO: jadechoghri
         # add subtask prediction loss
         # add fast action prediction loss
-        suffix_out = self._apply_checkpoint(
+        prefix_out, suffix_out = self._apply_checkpoint(
             forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
         )
 
+        # add fast action prediction loss
+        # Get logits for FAST action tokens using the FAST LM head
+        # only compute logits for the positions that predict FAST tokens
+        lm_head = self.paligemma_with_expert.paligemma.lm_head
+
+        # Targets are the FAST action tokens
+        fast_targets = action_tokens  # (B, fast_len)
+
+        # extract logits for FAST token prediction
+        fast_hidden = prefix_out[:, -fast_targets.shape[1] :, :]
+        fast_logits_for_pred = lm_head(fast_hidden)  # (B, fast_len, gemma_vocab_size)
+
+        # Shift left for next-step prediction and shift target
+        # logits[:, i] predicts targets[:, i+1]
+        fast_logits_for_pred = fast_logits_for_pred[:, :-1, :]  # drop last position: nothing follows it
+        fast_targets = fast_targets[:, 1:]  # drop first token: no FAST position predicts it
+        fast_action_masks = action_masks[:, 1:]  # keep masks aligned with the shifted targets
+
+        # compute cross-entropy loss
+        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+        fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
+        fast_targets_flat = fast_targets.reshape(-1)
+
+        fast_loss_per_token = loss_fct(fast_logits_flat, fast_targets_flat)
+        fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape)
+
+        # apply mask and compute mean loss
+        masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
+        fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
+
+
+        # add subtask prediction loss
+        subtask_targets = subtask_tokens
+
+        # extract logits for subtask token prediction
+        subtask_hidden = prefix_out[:, -(fast_len+subtask_len):-fast_len, :]
+        subtask_logits_for_pred = lm_head(subtask_hidden)  # (B, subtask_len, gemma_vocab_size)
+
+        # Shift left for next-step prediction and shift target
+        # logits[:, i] predicts targets[:, i+1]
+        subtask_logits_for_pred = subtask_logits_for_pred[:, :-1, :]  # drop last position: nothing follows it
+        subtask_targets = subtask_targets[:, 1:]  # drop first token: no subtask position predicts it
+        subtask_masks_shifted = subtask_masks[:, 1:]  # keep masks aligned with the shifted targets
+
+        # compute cross-entropy loss
+        logits_flat = subtask_logits_for_pred.reshape(-1, subtask_logits_for_pred.size(-1))
+        targets_flat = subtask_targets.reshape(-1)
+        subtask_loss_per_token = loss_fct(logits_flat, targets_flat)
+        subtask_loss_per_token = subtask_loss_per_token.reshape(subtask_targets.shape)
+
+        # apply mask and compute mean loss
+        masked_subtask_loss = subtask_loss_per_token * subtask_masks_shifted.float()
+        subtask_loss = masked_subtask_loss.sum() / subtask_masks_shifted.sum().clamp(min=1)
+
         suffix_out = suffix_out[:, -self.config.chunk_size :]
         suffix_out = suffix_out.to(dtype=torch.float32)
 
@@ -918,8 +1006,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
             return self.action_out_proj(suffix_out)
 
         v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
+        flow_loss = F.mse_loss(u_t, v_t, reduction="none")
 
-        return F.mse_loss(u_t, v_t, reduction="none")
+        return {
+            "flow_loss": flow_loss.mean(),
+            "action_ce_loss": fast_loss.mean(),
+            "subtask_ce_loss": subtask_loss,
+            "loss": flow_loss.mean() + 0.1 * subtask_loss.mean() + 0.05 * fast_loss.mean(),  # TODO: jadechoghari: check weights
+        }
 
     @torch.no_grad()  # see openpi `sample_actions` (slightly adapted)
     def sample_actions(
@@ -948,8 +1042,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
             )  # Use config max_action_dim for internal processing
             noise = self.sample_noise(actions_shape, device)
 
-        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
-        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
+        prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(images, img_masks, tokens, None, masks, None)
+        # prefix_att_masks is already a 2D attention mask from _create_custom_attention_mask
+        prefix_att_2d_masks = prefix_att_masks
         prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
 
         prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
@@ -1398,23 +1493,17 @@ class PI05FullPolicy(PreTrainedPolicy):
         actions = self.prepare_action(batch)
 
         # Compute loss (no separate state needed for PI05)
-        losses = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions)
+        loss_dict = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions)
 
-        # Truncate losses to actual action dimensions
-        original_action_dim = self.config.output_features[ACTION].shape[0]
-        losses = losses[:, :, :original_action_dim]
-
-        loss_dict = {
-            "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
+        # Extract the total loss
+        loss = loss_dict["loss"]
+
+        # Prepare detailed loss dictionary for logging
+        detailed_loss_dict = {
+            "loss": loss.item(),
+            "flow_loss": loss_dict["flow_loss"].mean().item(),
+            "subtask_ce_loss": loss_dict["subtask_ce_loss"].item(),
+            "action_ce_loss": loss_dict["action_ce_loss"].item(),
         }
 
-        if reduction == "none":
-            # Return per-sample losses (B,) by averaging over time and action dims
-            per_sample_loss = losses.mean(dim=(1, 2))
-            loss_dict["loss"] = per_sample_loss.mean().item()
-            return per_sample_loss, loss_dict
-        else:
-            # Default: return scalar mean loss
-            loss = losses.mean()
-            loss_dict["loss"] = loss.item()
-            return loss, loss_dict
+        return loss, detailed_loss_dict