mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
add three losses: flow_mse, subtask_ce, action_ce
This commit is contained in:
@@ -730,9 +730,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_mask_segments = [] # Store info about each segment for custom mask creation
|
||||
total_T_images = 0
|
||||
num_subtask_embs = 0
|
||||
num_fast_embs = 0
|
||||
image_len = 0
|
||||
language_len = 0
|
||||
subtask_len = 0
|
||||
fast_len = 0
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
@@ -746,7 +747,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
att_mask_segments.append(('image', num_img_embs))
|
||||
total_T_images += num_img_embs
|
||||
image_len += num_img_embs
|
||||
|
||||
# Process language instruction tokens
|
||||
def lang_embed_func(tokens):
|
||||
@@ -774,8 +775,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# create subtask pad masks (non-zero tokens are valid)
|
||||
pad_masks.append(subtask_masks)
|
||||
|
||||
num_subtask_embs = subtask_emb.shape[1]
|
||||
att_mask_segments.append(('subtask', num_subtask_embs))
|
||||
subtask_len = subtask_emb.shape[1]
|
||||
att_mask_segments.append(('subtask', subtask_len))
|
||||
|
||||
# Process FAST action tokens (discrete token IDs)
|
||||
if fast_action_tokens is not None:
|
||||
@@ -788,9 +789,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||
embs.append(fast_action_emb)
|
||||
|
||||
num_fast_embs = fast_action_tokens.shape[1]
|
||||
fast_len = fast_action_tokens.shape[1]
|
||||
pad_masks.append(fast_action_masks)
|
||||
att_mask_segments.append(("fast", num_fast_embs))
|
||||
att_mask_segments.append(("fast", fast_len))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
@@ -811,7 +812,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# max_display_tokens=512 # Limit display for very long sequences
|
||||
# )
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
return embs, pad_masks, att_masks, image_len
|
||||
|
||||
def embed_suffix(self, noisy_actions, timestep):
|
||||
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
@@ -872,29 +873,62 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks)
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, image_len = self.embed_prefix(images, img_masks, high_level_task_tokens, subtask_tokens, high_level_task_masks, subtask_masks, action_tokens, action_masks)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
#TODO jadechoghari
|
||||
# this attention part should be reworked
|
||||
breakpoint()
|
||||
suffix_att_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
# prefix_att_masks: shape [B, prefix_len, prefix_len]
|
||||
# suffix_att_masks: shape [B, suffix_len, suffix_len]
|
||||
# get: image_len, language_len, subtask_len, fast_len
|
||||
language_len = high_level_task_tokens.shape[1]
|
||||
subtask_len = subtask_tokens.shape[1]
|
||||
fast_len = action_tokens.shape[1]
|
||||
|
||||
# create the correct attention mask for combining the prefix and suffix
|
||||
# suffix embeddings from the action expert embeddings attend
|
||||
# to the prefix and to one another, but do not attend to FAST action tokens
|
||||
# final shape: [B, prefix_len+suffix_len, prefix_len+suffix_len]
|
||||
# TODO: jadechoghari: put it in a function, and be more efficient
|
||||
bsize = prefix_embs.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
total_len = prefix_len + suffix_len
|
||||
|
||||
# start with 0 mask
|
||||
combined_att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=prefix_embs.device)
|
||||
|
||||
# top-left: prefix attends to prefix (using the custom prefix attention mask)
|
||||
combined_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_masks
|
||||
|
||||
# bottom-right: suffix attends to suffix (using suffix attention mask)
|
||||
combined_att_2d_masks[:, prefix_len:, prefix_len:] = suffix_att_masks
|
||||
|
||||
# bottom-left: suffix attends to prefix EXCEPT FAST tokens
|
||||
# calculate where FAST tokens are in the prefix
|
||||
prefix_without_fast_len = image_len + language_len + subtask_len
|
||||
|
||||
# suffix can attend to images, language, and subtask tokens (but not FAST)
|
||||
combined_att_2d_masks[:, prefix_len:, :prefix_without_fast_len] = True
|
||||
|
||||
# apply padding masks to the combined attention mask
|
||||
combined_pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
pad_2d_masks = combined_pad_masks[:, None, :] * combined_pad_masks[:, :, None]
|
||||
att_2d_masks = combined_att_2d_masks & pad_2d_masks
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
position_ids = torch.cumsum(combined_pad_masks, dim=1) - 1
|
||||
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
||||
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
(prefix_out, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
@@ -902,15 +936,69 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
return suffix_out
|
||||
return prefix_out, suffix_out
|
||||
|
||||
# TODO: jadechoghri
|
||||
# add subtask prediction loss
|
||||
# add fast action prediction loss
|
||||
suffix_out = self._apply_checkpoint(
|
||||
prefix_out, suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
# add fast action prediction loss
|
||||
# Get logits for FAST action tokens using the FAST LM head
|
||||
# only compute logits for the positions that predict FAST tokens
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
# Targets are the FAST action tokens
|
||||
fast_targets = action_tokens # (B, num_fast_embs)
|
||||
|
||||
# extract logits for FAST token prediction
|
||||
fast_hidden = prefix_out[:, -fast_targets.shape[1] :, :]
|
||||
fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size)
|
||||
|
||||
# Shift left for next-step prediction and shift target
|
||||
# logits[:, i] predicts targets[:, i+1]
|
||||
fast_logits_for_pred = fast_logits_for_pred[:, :-1, :] # shift logits left
|
||||
fast_targets = fast_targets[:, 1:] # shift targets right
|
||||
fast_action_masks = action_masks[:, 1:] # shift masks to match targets
|
||||
|
||||
# compute cross-entropy loss
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
|
||||
fast_targets_flat = fast_targets.reshape(-1)
|
||||
|
||||
fast_loss_per_token = loss_fct(fast_logits_flat, fast_targets_flat)
|
||||
fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape)
|
||||
|
||||
# apply mask and compute mean loss
|
||||
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
|
||||
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
|
||||
|
||||
|
||||
# add subtask prediction loss
|
||||
subtask_targets = subtask_tokens
|
||||
|
||||
# extract logits for subtask token prediction
|
||||
subtask_hidden = prefix_out[:, -(fast_len+subtask_len):-fast_len, :]
|
||||
subtask_logits_for_pred = lm_head(subtask_hidden) # (B, num_subtask_embs, gemma_vocab_size)
|
||||
|
||||
# Shift left for next-step prediction and shift target
|
||||
# logits[:, i] predicts targets[:, i+1]
|
||||
subtask_logits_for_pred = subtask_logits_for_pred[:, :-1, :] # shift logits left
|
||||
subtask_targets = subtask_targets[:, 1:] # shift targets right
|
||||
subtask_masks_shifted = subtask_masks[:, 1:] # shift masks to match targets
|
||||
|
||||
# compute cross-entropy loss
|
||||
logits_flat = subtask_logits_for_pred.reshape(-1, subtask_logits_for_pred.size(-1))
|
||||
targets_flat = subtask_targets.reshape(-1)
|
||||
subtask_loss_per_token = loss_fct(logits_flat, targets_flat)
|
||||
subtask_loss_per_token = subtask_loss_per_token.reshape(subtask_targets.shape)
|
||||
|
||||
# apply mask and compute mean loss
|
||||
masked_subtask_loss = subtask_loss_per_token * subtask_masks_shifted.float()
|
||||
subtask_loss = masked_subtask_loss.sum() / subtask_masks_shifted.sum().clamp(min=1)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
|
||||
@@ -918,8 +1006,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
||||
flow_loss = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
return {
|
||||
"flow_loss": flow_loss.mean(),
|
||||
"action_ce_loss": fast_loss.mean(),
|
||||
"subtask_ce_loss": subtask_loss,
|
||||
"loss": flow_loss.mean() + 0.1 * subtask_loss.mean() + 0.05 * fast_loss.mean(), # TODO: jadechoghari: check weights
|
||||
}
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(
|
||||
@@ -948,8 +1042,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
) # Use config max_action_dim for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, _, _ = self.embed_prefix(images, img_masks, tokens, None, masks, None)
|
||||
# prefix_att_masks is already a 2D attention mask from _create_custom_attention_mask
|
||||
prefix_att_2d_masks = prefix_att_masks
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
@@ -1398,23 +1493,17 @@ class PI05FullPolicy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions)
|
||||
loss_dict = self.model.forward(images, img_masks, high_level_task_tokens, high_level_task_masks, subtask_tokens, subtask_masks, action_tokens, action_masks, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[:, :, :original_action_dim]
|
||||
|
||||
loss_dict = {
|
||||
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
|
||||
# Extract the total loss
|
||||
loss = loss_dict["loss"]
|
||||
|
||||
# Prepare detailed loss dictionary for logging
|
||||
detailed_loss_dict = {
|
||||
"loss": loss.item(),
|
||||
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
||||
"subtask_ce_loss": loss_dict["subtask_ce_loss"].item(),
|
||||
"action_ce_loss": loss_dict["action_ce_loss"].item(),
|
||||
}
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
return loss, detailed_loss_dict
|
||||
|
||||
Reference in New Issue
Block a user