make it work

This commit is contained in:
Jade Choghari
2025-12-29 17:34:18 +01:00
parent 23d4846423
commit 995a46b302
2 changed files with 55 additions and 35 deletions
+54 -35
View File
@@ -538,11 +538,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
# # FAST action token embedding and prediction head
# self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
# self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
from transformers import AutoTokenizer
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224",
trust_remote_code=True,
)
# # Apply dtype conversion to FAST layers to match model precision
# if config.dtype == "bfloat16":
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
@@ -1498,6 +1503,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
device = tokens.device
lm_head = self.paligemma_with_expert.paligemma.lm_head
# add bos token after tokens
bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
tokens = torch.cat([tokens, bos_token], dim=1)
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
# 1. Initial Embedding (Matches Training Prefix)
# prefix_embs will include [Images, Language Prompt]
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
@@ -1579,8 +1589,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
Efficient autoregressive decoding for FAST tokens using KV-caching.
Only computes the prefix once, then incrementally generates tokens.
"""
from transformers import AutoTokenizer
self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True)
if max_decoding_steps is None:
max_decoding_steps = self.config.max_action_tokens
@@ -1588,6 +1596,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
device = tokens.device
lm_head = self.paligemma_with_expert.paligemma.lm_head
bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
tokens = torch.cat([tokens, bos_token], dim=1)
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
# 1. Initial Embedding (Matches Training Prefix)
# prefix_embs will include [Images, Language Prompt]
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
@@ -1615,49 +1627,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
adarms_cond=[None, None],
)
# Get BOS token and add it as the first token in action sequence
bos_id = self._paligemma_tokenizer.bos_token_id
bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)
# # Get BOS token and add it as the first token in action sequence
# bos_id = self._paligemma_tokenizer.bos_token_id
# bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)
# Embed BOS token
bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)
# # Embed BOS token
# bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
# bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
# if prefix_embs.dtype == torch.bfloat16:
# bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)
# Track current sequence length for position IDs and maintain the padding mask
current_seq_len = prefix_embs.shape[1]
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
current_pad_mask = prefix_pad_masks.clone() # (B, seq_len)
# Update padding mask for BOS token: it's always valid
current_pad_mask = torch.cat([
current_pad_mask,
torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1) # (B, seq_len+1)
# # Update padding mask for BOS token: it's always valid
# current_pad_mask = torch.cat([
# current_pad_mask,
# torch.ones((bsize, 1), dtype=torch.bool, device=device)
# ], dim=1) # (B, seq_len+1)
# Position ID for BOS token (continues from where prefix ended)
bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
# # Position ID for BOS token (continues from where prefix ended)
# bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
# Attention mask for BOS token: attends to all valid prefix positions
bos_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)
# # Attention mask for BOS token: attends to all valid prefix positions
# bos_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
# bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)
# Forward pass with BOS token (reusing cached KVs from prefix)
(bos_out, _), past_key_values = self.paligemma_with_expert.forward(
attention_mask=bos_att_4d,
position_ids=bos_position_id,
past_key_values=past_key_values,
inputs_embeds=[bos_token_emb, None],
use_cache=True,
adarms_cond=[None, None],
)
# # Forward pass with BOS token (reusing cached KVs from prefix)
# (bos_out, _), past_key_values = self.paligemma_with_expert.forward(
# attention_mask=bos_att_4d,
# position_ids=bos_position_id,
# past_key_values=past_key_values,
# inputs_embeds=[bos_token_emb, None],
# use_cache=True,
# adarms_cond=[None, None],
# )
# Update sequence length to account for BOS token
current_seq_len += 1
# # Update sequence length to account for BOS token
# current_seq_len += 1
# Predict first action token from BOS token output
last_logits = lm_head(bos_out[:, -1:, :]) # (B, 1, vocab_size)
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
if temperature > 0:
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
@@ -2399,6 +2411,14 @@ class PI05Policy(PreTrainedPolicy):
# Clean tokens by removing everything after the first "|" (end-of-action marker)
# and removing all occurrences of "Action: " token sequence
# assert that beginning contain "Action: "
for token_seq in decoded_tokens:
assert (
len(token_seq) >= 2
and token_seq[0] == "Action"
and token_seq[1] == ":"
), f"Token sequence does not start with ['Action', ':']: {token_seq}"
cleaned_tokens = []
for token_seq in decoded_tokens:
# Remove everything after "|"
@@ -2483,12 +2503,11 @@ class PI05Policy(PreTrainedPolicy):
max_decoding_steps = 256
# Sample action tokens autoregressively
action_tokens = self.model.sample_actions_fast_kv_cache(
action_tokens = self.model.sample_actions_fast(
images, img_masks, tokens, masks,
max_decoding_steps=max_decoding_steps,
temperature=temperature,
)
# Detokenize action tokens to continuous actions
action_horizon = self.config.n_action_steps
action_dim = 7
+1
View File
@@ -173,6 +173,7 @@ def rollout(
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)