mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-17 01:30:14 +00:00
make it work
@@ -538,11 +538,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)

# # FAST action token embedding and prediction head
# self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
# self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)

from transformers import AutoTokenizer
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
    "google/paligemma-3b-pt-224",
    trust_remote_code=True,
)
# # Apply dtype conversion to FAST layers to match model precision
# if config.dtype == "bfloat16":
#     self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
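
Editor's note: as a standalone illustration of the constructor change above, here is a minimal sketch of loading the PaliGemma tokenizer and reading the BOS id that the decoding code relies on later. It assumes the checkpoint can be downloaded from the Hub; it is not part of the diff.

from transformers import AutoTokenizer

# Same tokenizer the constructor stores on the model; its bos_token_id is what
# later gets appended to the prompt before autoregressive decoding.
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True)
print(tokenizer.bos_token_id)
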
@@ -1498,6 +1503,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
device = tokens.device
lm_head = self.paligemma_with_expert.paligemma.lm_head

# add bos token after tokens
bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
tokens = torch.cat([tokens, bos_token], dim=1)
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)

# 1. Initial Embedding (Matches Training Prefix)
# prefix_embs will include [Images, Language Prompt]
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
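
Editor's note: a self-contained sketch of the two concatenations above, with made-up shapes and a made-up BOS id, just to show how the token and mask tensors each grow by one column.

import torch

bsize, bos_id = 2, 2                                  # made-up batch size and BOS id
tokens = torch.randint(0, 100, (bsize, 5))            # (B, T) prompt token ids
masks = torch.ones((bsize, 5), dtype=torch.bool)      # (B, T) padding mask

# Append one BOS id per sequence and mark that new position as valid.
bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long)
tokens = torch.cat([tokens, bos_token], dim=1)
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool)], dim=1)
print(tokens.shape, masks.shape)  # torch.Size([2, 6]) torch.Size([2, 6])
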
@@ -1579,8 +1589,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
Efficient autoregressive decoding for FAST tokens using KV-caching.
Only computes the prefix once, then incrementally generates tokens.
"""
from transformers import AutoTokenizer
self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True)
if max_decoding_steps is None:
    max_decoding_steps = self.config.max_action_tokens
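
Editor's note: the docstring above describes the generic KV-caching pattern, run the prefix once, then feed one token at a time while reusing past_key_values. A hedged sketch of that pattern with a small off-the-shelf causal LM ("gpt2" is only a stand-in that must be downloadable; the PI05 model and its FAST vocabulary are not reproduced here).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prompt_ids = tok("Action:", return_tensors="pt").input_ids
with torch.inference_mode():
    # 1. Encode the prefix once and keep the key/value cache.
    out = model(prompt_ids, use_cache=True)
    past = out.past_key_values
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    generated = [next_id]
    # 2. Incremental decoding: each step feeds a single token plus the cache,
    #    so the prefix is never recomputed.
    for _ in range(8):
        out = model(next_id, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_id)
print(tok.decode(torch.cat(generated, dim=1)[0]))
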
@@ -1588,6 +1596,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
device = tokens.device
lm_head = self.paligemma_with_expert.paligemma.lm_head

bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
tokens = torch.cat([tokens, bos_token], dim=1)
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)

# 1. Initial Embedding (Matches Training Prefix)
# prefix_embs will include [Images, Language Prompt]
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
@@ -1615,49 +1627,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
    adarms_cond=[None, None],
)

# Get BOS token and add it as the first token in action sequence
bos_id = self._paligemma_tokenizer.bos_token_id
bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)
# # Get BOS token and add it as the first token in action sequence
# bos_id = self._paligemma_tokenizer.bos_token_id
# bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)

# Embed BOS token
bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
    bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)
# # Embed BOS token
# bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
# bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
# if prefix_embs.dtype == torch.bfloat16:
#     bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)

# Track current sequence length for position IDs and maintain the padding mask
current_seq_len = prefix_embs.shape[1]
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
current_pad_mask = prefix_pad_masks.clone()  # (B, seq_len)

# Update padding mask for BOS token: it's always valid
current_pad_mask = torch.cat([
    current_pad_mask,
    torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1)  # (B, seq_len+1)
# # Update padding mask for BOS token: it's always valid
# current_pad_mask = torch.cat([
#     current_pad_mask,
#     torch.ones((bsize, 1), dtype=torch.bool, device=device)
# ], dim=1)  # (B, seq_len+1)

# Position ID for BOS token (continues from where prefix ended)
bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
# # Position ID for BOS token (continues from where prefix ended)
# bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)

# Attention mask for BOS token: attends to all valid prefix positions
bos_att_mask_2d = current_pad_mask.unsqueeze(1)  # (B, 1, seq_len+1)
bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)
# # Attention mask for BOS token: attends to all valid prefix positions
# bos_att_mask_2d = current_pad_mask.unsqueeze(1)  # (B, 1, seq_len+1)
# bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)

# Forward pass with BOS token (reusing cached KVs from prefix)
(bos_out, _), past_key_values = self.paligemma_with_expert.forward(
    attention_mask=bos_att_4d,
    position_ids=bos_position_id,
    past_key_values=past_key_values,
    inputs_embeds=[bos_token_emb, None],
    use_cache=True,
    adarms_cond=[None, None],
)
# # Forward pass with BOS token (reusing cached KVs from prefix)
# (bos_out, _), past_key_values = self.paligemma_with_expert.forward(
#     attention_mask=bos_att_4d,
#     position_ids=bos_position_id,
#     past_key_values=past_key_values,
#     inputs_embeds=[bos_token_emb, None],
#     use_cache=True,
#     adarms_cond=[None, None],
# )

# Update sequence length to account for BOS token
current_seq_len += 1
# # Update sequence length to account for BOS token
# current_seq_len += 1

# Predict first action token from BOS token output
last_logits = lm_head(bos_out[:, -1:, :])  # (B, 1, vocab_size)
last_logits = lm_head(prefix_out[:, -1:, :])  # (B, 1, vocab_size)

if temperature > 0:
    probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
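
Editor's note: the last lines above are the standard temperature-sampling step on the lm_head logits. A minimal sketch of that step in isolation; the shapes are assumptions, and the temperature == 0 fallback to greedy argmax mirrors the usual convention rather than this file's exact branch.

import torch

def sample_token(last_logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """last_logits: (B, 1, vocab_size) -> (B, 1) sampled token ids."""
    if temperature > 0:
        # Scale logits by the temperature, normalize, then sample.
        probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)
    # Greedy decoding when temperature is zero.
    return last_logits[:, -1].argmax(dim=-1, keepdim=True)

print(sample_token(torch.randn(2, 1, 32), temperature=0.7).shape)  # torch.Size([2, 1])
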
@@ -2399,6 +2411,14 @@ class PI05Policy(PreTrainedPolicy):

# Clean tokens by removing everything after the first "|" (end-of-action marker)
# and removing all occurrences of the "Action: " token sequence
# assert that each sequence begins with "Action: "
for token_seq in decoded_tokens:
    assert (
        len(token_seq) >= 2
        and token_seq[0] == "Action"
        and token_seq[1] == ":"
    ), f"Token sequence does not start with ['Action', ':']: {token_seq}"

cleaned_tokens = []
for token_seq in decoded_tokens:
    # Remove everything after "|"
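
Editor's note: a self-contained sketch of the cleaning the comments above describe, check for the leading ["Action", ":"] marker, drop it, and truncate at the first "|" end-of-action marker. The example sequence is invented, and only the leading marker is stripped here.

def clean_token_seq(token_seq: list[str]) -> list[str]:
    # Sequences are expected to start with the "Action: " marker.
    assert len(token_seq) >= 2 and token_seq[:2] == ["Action", ":"], token_seq
    body = token_seq[2:]
    # Everything from the first "|" onwards is end-of-action / padding.
    if "|" in body:
        body = body[: body.index("|")]
    return body

print(clean_token_seq(["Action", ":", "tok_12", "tok_7", "|", "pad"]))  # ['tok_12', 'tok_7']
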
@@ -2483,12 +2503,11 @@ class PI05Policy(PreTrainedPolicy):
max_decoding_steps = 256

# Sample action tokens autoregressively
action_tokens = self.model.sample_actions_fast_kv_cache(
action_tokens = self.model.sample_actions_fast(
    images, img_masks, tokens, masks,
    max_decoding_steps=max_decoding_steps,
    temperature=temperature,
)

# Detokenize action tokens to continuous actions
action_horizon = self.config.n_action_steps
action_dim = 7
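
Editor's note: the FAST detokenization itself is not visible in this hunk; the only contract shown is the output shape. A rough sketch of that shape contract, where the horizon value is a pure placeholder for config.n_action_steps.

import torch

action_horizon = 50   # placeholder for config.n_action_steps, not a value from the diff
action_dim = 7        # matches the hard-coded value above

# Stand-in for whatever the detokenizer returns per sample: a flat vector that
# must be viewable as (action_horizon, action_dim) continuous actions.
flat = torch.randn(action_horizon * action_dim)
actions = flat.view(action_horizon, action_dim)
print(actions.shape)  # torch.Size([50, 7])
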
@@ -173,6 +173,7 @@ def rollout(
observation = env_preprocessor(observation)

observation = preprocessor(observation)

with torch.inference_mode():
    action = policy.select_action(observation)
action = postprocessor(action)
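
Editor's note: the rollout hunk shows the per-step pattern, preprocess the observation, query the policy under torch.inference_mode, then postprocess the action. A hedged sketch of that pattern with stand-in callables; none of the names below come from the diff except the ordering of the calls.

import torch

def policy_step(observation, policy, env_preprocessor, preprocessor, postprocessor):
    # Same ordering as in the rollout loop above; all arguments are stand-ins.
    observation = env_preprocessor(observation)
    observation = preprocessor(observation)
    with torch.inference_mode():
        action = policy.select_action(observation)
    return postprocessor(action)

# Trivial usage with identity stand-ins, just to exercise the control flow.
identity = lambda x: x

class _DummyPolicy:
    def select_action(self, obs):
        return torch.zeros(1, 7)

print(policy_step({"state": torch.zeros(1, 3)}, _DummyPolicy(), identity, identity, identity).shape)
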