add three losses: flow_mse, subtask_ce, action_ce

This commit is contained in:
Jade Choghari
2026-01-14 14:52:32 +00:00
parent b57504b89e
commit 72b0af4ed7
+129 -40
View File
@@ -730,9 +730,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
embs = [] embs = []
pad_masks = [] pad_masks = []
att_mask_segments = [] # Store info about each segment for custom mask creation att_mask_segments = [] # Store info about each segment for custom mask creation
total_T_images = 0 image_len = 0
num_subtask_embs = 0 language_len = 0
num_fast_embs = 0 subtask_len = 0
fast_len = 0
# Process images # Process images
for img, img_mask in zip(images, img_masks, strict=True): for img, img_mask in zip(images, img_masks, strict=True):
@@ -746,7 +747,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
embs.append(img_emb) embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
att_mask_segments.append(('image', num_img_embs)) att_mask_segments.append(('image', num_img_embs))
total_T_images += num_img_embs image_len += num_img_embs
# Process language instruction tokens # Process language instruction tokens
def lang_embed_func(tokens): def lang_embed_func(tokens):
@@ -774,8 +775,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# create subtask pad masks (non-zero tokens are valid) # create subtask pad masks (non-zero tokens are valid)
pad_masks.append(subtask_masks) pad_masks.append(subtask_masks)
num_subtask_embs = subtask_emb.shape[1] subtask_len = subtask_emb.shape[1]
att_mask_segments.append(('subtask', num_subtask_embs)) att_mask_segments.append(('subtask', subtask_len))
# Process FAST action tokens (discrete token IDs) # Process FAST action tokens (discrete token IDs)
if fast_action_tokens is not None: if fast_action_tokens is not None:
@@ -788,9 +789,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens) fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
embs.append(fast_action_emb) embs.append(fast_action_emb)
num_fast_embs = fast_action_tokens.shape[1] fast_len = fast_action_tokens.shape[1]
pad_masks.append(fast_action_masks) pad_masks.append(fast_action_masks)
att_mask_segments.append(("fast", num_fast_embs)) att_mask_segments.append(("fast", fast_len))
embs = torch.cat(embs, dim=1) embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1) pad_masks = torch.cat(pad_masks, dim=1)
@@ -811,7 +812,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# max_display_tokens=512 # Limit display for very long sequences # max_display_tokens=512 # Limit display for very long sequences
# ) # )
return embs, pad_masks, att_masks return embs, pad_masks, att_masks, image_len
def embed_suffix(self, noisy_actions, timestep): def embed_suffix(self, noisy_actions, timestep):
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing.""" """Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
@@ -872,29 +873,62 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
x_t = time_expanded * noise + (1 - time_expanded) * actions x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks) prefix_embs, prefix_pad_masks, prefix_att_masks, image_len = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
#TODO jadechoghari #TODO jadechoghari
# this attention part should be reworked # this attention part should be reworked
breakpoint() suffix_att_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
# prefix_att_masks: shape [B, prefix_len, prefix_len]
# suffix_att_masks: shape [B, suffix_len, suffix_len]
# get: image_len, language_len, subtask_len, fast_len
language_len = high_level_task_tokens.shape[1]
subtask_len = subtask_tokens.shape[1]
fast_len = action_tokens.shape[1]
# create the correct attention mask for combining the prefix and suffix
# suffix embeddings from the action expert embeddings attend
# to the prefix and to one another, but do not attend to FAST action tokens
# final shape: [B, prefix_len+suffix_len, prefix_len+suffix_len]
# TODO: jadechoghari: put it in a function, and be more efficient
bsize = prefix_embs.shape[0]
prefix_len = prefix_pad_masks.shape[1]
suffix_len = suffix_pad_masks.shape[1]
total_len = prefix_len + suffix_len
# start with 0 mask
combined_att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=prefix_embs.device)
# top-left: prefix attends to prefix (using the custom prefix attention mask)
combined_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_masks
# bottom-right: suffix attends to suffix (using suffix attention mask)
combined_att_2d_masks[:, prefix_len:, prefix_len:] = suffix_att_masks
# bottom-left: suffix attends to prefix EXCEPT FAST tokens
# calculate where FAST tokens are in the prefix
prefix_without_fast_len = image_len + language_len + subtask_len
# suffix can attend to images, language, and subtask tokens (but not FAST)
combined_att_2d_masks[:, prefix_len:, :prefix_without_fast_len] = True
# apply padding masks to the combined attention mask
combined_pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
pad_2d_masks = combined_pad_masks[:, None, :] * combined_pad_masks[:, :, None]
att_2d_masks = combined_att_2d_masks & pad_2d_masks
if ( if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16 == torch.bfloat16
): ):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16) suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16) prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
position_ids = torch.cumsum(combined_pad_masks, dim=1) - 1
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward( (prefix_out, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d, attention_mask=att_2d_masks_4d,
position_ids=position_ids, position_ids=position_ids,
past_key_values=None, past_key_values=None,
@@ -902,15 +936,69 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
use_cache=False, use_cache=False,
adarms_cond=[None, adarms_cond], adarms_cond=[None, adarms_cond],
) )
return suffix_out return prefix_out, suffix_out
# TODO: jadechoghri # TODO: jadechoghri
# add subtask prediction loss # add subtask prediction loss
# add fast action prediction loss # add fast action prediction loss
suffix_out = self._apply_checkpoint( prefix_out, suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
) )
# add fast action prediction loss
# Get logits for FAST action tokens using the FAST LM head
# only compute logits for the positions that predict FAST tokens
lm_head = self.paligemma_with_expert.paligemma.lm_head
# Targets are the FAST action tokens
fast_targets = action_tokens # (B, num_fast_embs)
# extract logits for FAST token prediction
fast_hidden = prefix_out[:, -fast_targets.shape[1] :, :]
fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size)
# Shift left for next-step prediction and shift target
# logits[:, i] predicts targets[:, i+1]
fast_logits_for_pred = fast_logits_for_pred[:, :-1, :] # shift logits left
fast_targets = fast_targets[:, 1:] # shift targets right
fast_action_masks = action_masks[:, 1:] # shift masks to match targets
# compute cross-entropy loss
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
fast_targets_flat = fast_targets.reshape(-1)
fast_loss_per_token = loss_fct(fast_logits_flat, fast_targets_flat)
fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape)
# apply mask and compute mean loss
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
# add subtask prediction loss
subtask_targets = subtask_tokens
# extract logits for subtask token prediction
subtask_hidden = prefix_out[:, -(fast_len+subtask_len):-fast_len, :]
subtask_logits_for_pred = lm_head(subtask_hidden) # (B, num_subtask_embs, gemma_vocab_size)
# Shift left for next-step prediction and shift target
# logits[:, i] predicts targets[:, i+1]
subtask_logits_for_pred = subtask_logits_for_pred[:, :-1, :] # shift logits left
subtask_targets = subtask_targets[:, 1:] # shift targets right
subtask_masks_shifted = subtask_masks[:, 1:] # shift masks to match targets
# compute cross-entropy loss
logits_flat = subtask_logits_for_pred.reshape(-1, subtask_logits_for_pred.size(-1))
targets_flat = subtask_targets.reshape(-1)
subtask_loss_per_token = loss_fct(logits_flat, targets_flat)
subtask_loss_per_token = subtask_loss_per_token.reshape(subtask_targets.shape)
# apply mask and compute mean loss
masked_subtask_loss = subtask_loss_per_token * subtask_masks_shifted.float()
subtask_loss = masked_subtask_loss.sum() / subtask_masks_shifted.sum().clamp(min=1)
suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32) suffix_out = suffix_out.to(dtype=torch.float32)
@@ -918,8 +1006,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return self.action_out_proj(suffix_out) return self.action_out_proj(suffix_out)
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
flow_loss = F.mse_loss(u_t, v_t, reduction="none")
return F.mse_loss(u_t, v_t, reduction="none") return {
"flow_loss": flow_loss.mean(),
"action_ce_loss": fast_loss.mean(),
"subtask_ce_loss": subtask_loss,
"loss": flow_loss.mean() + 0.1 * subtask_loss.mean() + 0.05 * fast_loss.mean(), # TODO: jadechoghari: check weights
}
@torch.no_grad() # see openpi `sample_actions` (slightly adapted) @torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions( def sample_actions(
@@ -948,8 +1042,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
) # Use config max_action_dim for internal processing ) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device) noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) prefix_embs, prefix_pad_masks, prefix_att_masks, _, _ = self.embed_prefix(images, img_masks, tokens, None, masks, None)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) # prefix_att_masks is already a 2D attention mask from _create_custom_attention_mask
prefix_att_2d_masks = prefix_att_masks
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
@@ -1398,23 +1493,17 @@ class PI05FullPolicy(PreTrainedPolicy):
actions = self.prepare_action(batch) actions = self.prepare_action(batch)
# Compute loss (no separate state needed for PI05) # Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions) loss_dict = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions)
# Truncate losses to actual action dimensions # Extract the total loss
original_action_dim = self.config.output_features[ACTION].shape[0] loss = loss_dict["loss"]
losses = losses[:, :, :original_action_dim]
# Prepare detailed loss dictionary for logging
loss_dict = { detailed_loss_dict = {
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), "loss": loss.item(),
"flow_loss": loss_dict["flow_loss"].mean().item(),
"subtask_ce_loss": loss_dict["subtask_ce_loss"].item(),
"action_ce_loss": loss_dict["action_ce_loss"].item(),
} }
if reduction == "none": return loss, detailed_loss_dict
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss_dict["loss"] = loss.item()
return loss, loss_dict