diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
index 9e1579a5d..ca71fa09d 100644
--- a/src/lerobot/policies/pi05/modeling_pi05.py
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -538,11 +538,16 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
 
         self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
         self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
-
+
         # # FAST action token embedding and prediction head
         # self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
         # self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
+        from transformers import AutoTokenizer
+        self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
+            "google/paligemma-3b-pt-224",
+            trust_remote_code=True,
+        )
 
         # # Apply dtype conversion to FAST layers to match model precision
         # if config.dtype == "bfloat16":
         #     self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
@@ -1498,6 +1503,11 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         device = tokens.device
         lm_head = self.paligemma_with_expert.paligemma.lm_head
 
+        # Append the BOS token after the prompt tokens
+        bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
+        tokens = torch.cat([tokens, bos_token], dim=1)
+        masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
+
         # 1. Initial Embedding (Matches Training Prefix)
         # prefix_embs will include [Images, Language Prompt]
         prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
@@ -1579,8 +1589,6 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         Efficient autoregressive decoding for FAST tokens using KV-caching.
         Only computes the prefix once, then incrementally generates tokens.
         """
-        from transformers import AutoTokenizer
-        self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True)
 
         if max_decoding_steps is None:
             max_decoding_steps = self.config.max_action_tokens
@@ -1588,6 +1596,10 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         device = tokens.device
         lm_head = self.paligemma_with_expert.paligemma.lm_head
 
+        bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
+        tokens = torch.cat([tokens, bos_token], dim=1)
+        masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
+
         # 1. Initial Embedding (Matches Training Prefix)
         # prefix_embs will include [Images, Language Prompt]
         prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
@@ -1615,49 +1627,49 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
             adarms_cond=[None, None],
         )
 
-        # Get BOS token and add it as the first token in action sequence
-        bos_id = self._paligemma_tokenizer.bos_token_id
-        bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)
+        # # Get BOS token and add it as the first token in action sequence
+        # bos_id = self._paligemma_tokenizer.bos_token_id
+        # bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)
 
-        # Embed BOS token
-        bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
-        bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
-        if prefix_embs.dtype == torch.bfloat16:
-            bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)
+        # # Embed BOS token
+        # bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
+        # bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
+        # if prefix_embs.dtype == torch.bfloat16:
+        #     bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)
 
         # Track current sequence length for position IDs and maintain the padding mask
         current_seq_len = prefix_embs.shape[1]
         # Keep track of valid positions: prefix_pad_masks tells us which positions are valid
         current_pad_mask = prefix_pad_masks.clone()  # (B, seq_len)
 
-        # Update padding mask for BOS token: it's always valid
-        current_pad_mask = torch.cat([
-            current_pad_mask,
-            torch.ones((bsize, 1), dtype=torch.bool, device=device)
-        ], dim=1)  # (B, seq_len+1)
+        # # Update padding mask for BOS token: it's always valid
+        # current_pad_mask = torch.cat([
+        #     current_pad_mask,
+        #     torch.ones((bsize, 1), dtype=torch.bool, device=device)
+        # ], dim=1)  # (B, seq_len+1)
 
-        # Position ID for BOS token (continues from where prefix ended)
-        bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
+        # # Position ID for BOS token (continues from where prefix ended)
+        # bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
 
-        # Attention mask for BOS token: attends to all valid prefix positions
-        bos_att_mask_2d = current_pad_mask.unsqueeze(1)  # (B, 1, seq_len+1)
-        bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)
+        # # Attention mask for BOS token: attends to all valid prefix positions
+        # bos_att_mask_2d = current_pad_mask.unsqueeze(1)  # (B, 1, seq_len+1)
+        # bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)
 
-        # Forward pass with BOS token (reusing cached KVs from prefix)
-        (bos_out, _), past_key_values = self.paligemma_with_expert.forward(
-            attention_mask=bos_att_4d,
-            position_ids=bos_position_id,
-            past_key_values=past_key_values,
-            inputs_embeds=[bos_token_emb, None],
-            use_cache=True,
-            adarms_cond=[None, None],
-        )
+        # # Forward pass with BOS token (reusing cached KVs from prefix)
+        # (bos_out, _), past_key_values = self.paligemma_with_expert.forward(
+        #     attention_mask=bos_att_4d,
+        #     position_ids=bos_position_id,
+        #     past_key_values=past_key_values,
+        #     inputs_embeds=[bos_token_emb, None],
+        #     use_cache=True,
+        #     adarms_cond=[None, None],
+        # )
 
-        # Update sequence length to account for BOS token
-        current_seq_len += 1
+        # # Update sequence length to account for BOS token
+        # current_seq_len += 1
 
         # Predict first action token from BOS token output
-        last_logits = lm_head(bos_out[:, -1:, :])  # (B, 1, vocab_size)
+        last_logits = lm_head(prefix_out[:, -1:, :])  # (B, 1, vocab_size)
 
         if temperature > 0:
             probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
@@ -2399,6 +2411,14 @@ class PI05Policy(PreTrainedPolicy):
         # Clean tokens by removing everything after the first "|" (end-of-action marker)
         # and removing all occurrences of "Action: " token sequence
 
+        # Assert that each decoded token sequence begins with "Action: "
+        for token_seq in decoded_tokens:
+            assert (
+                len(token_seq) >= 2
+                and token_seq[0] == "Action"
+                and token_seq[1] == ":"
+            ), f"Token sequence does not start with ['Action', ':']: {token_seq}"
+
         cleaned_tokens = []
         for token_seq in decoded_tokens:
             # Remove everything after "|"
@@ -2483,12 +2503,11 @@ class PI05Policy(PreTrainedPolicy):
             max_decoding_steps = 256
 
         # Sample action tokens autoregressively
-        action_tokens = self.model.sample_actions_fast_kv_cache(
+        action_tokens = self.model.sample_actions_fast(
             images,
             img_masks,
             tokens,
             masks,
             max_decoding_steps=max_decoding_steps,
             temperature=temperature,
         )
-
         # Detokenize action tokens to continuous actions
         action_horizon = self.config.n_action_steps
         action_dim = 7
diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py
index d23b9d083..a9fb3f32f 100644
--- a/src/lerobot/scripts/lerobot_eval.py
+++ b/src/lerobot/scripts/lerobot_eval.py
@@ -173,6 +173,7 @@ def rollout(
         observation = env_preprocessor(observation)
         observation = preprocessor(observation)
 
-        action = policy.select_action(observation)
+        with torch.inference_mode():
+            action = policy.select_action(observation)
         action = postprocessor(action)
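A minimal standalone sketch of the BOS handling this patch introduces (illustration only; bsize, prompt_len, and bos_token_id below are placeholders, not values taken from the policy): the BOS id is appended as the last prompt token and marked valid in the mask, so the first FAST action token is read from the prefix output at that position.

import torch

# Placeholder inputs; in the model, bos_token_id comes from self._paligemma_tokenizer.bos_token_id
bsize, prompt_len = 2, 5
bos_token_id = 2
tokens = torch.randint(3, 100, (bsize, prompt_len), dtype=torch.long)
masks = torch.ones((bsize, prompt_len), dtype=torch.bool)

# Append BOS as the final prefix token and mark it valid in the padding mask
bos_col = torch.full((bsize, 1), bos_token_id, dtype=torch.long)
tokens = torch.cat([tokens, bos_col], dim=1)  # (B, prompt_len + 1)
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool)], dim=1)

# The first action-token logits are then taken from the last prefix position (the BOS slot),
# mirroring last_logits = lm_head(prefix_out[:, -1:, :]) in the patched decoding loop.
print(tokens.shape, masks.shape)  # torch.Size([2, 6]) torch.Size([2, 6])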