add fast tokenizer support

This commit is contained in:
Jade Choghari
2025-12-16 11:28:27 +00:00
parent fddd044306
commit 8e05dc9a7a
13 changed files with 1081 additions and 204 deletions
@@ -0,0 +1,138 @@
#!/usr/bin/env python
"""
Example demonstrating how to use the ActionTokenizerProcessorStep to tokenize actions.
This example shows how to:
1. Load a dataset with action data
2. Apply the action tokenizer processor to tokenize actions with proper padding/truncation
3. Access both the tokenized actions and the attention mask
4. Decode tokenized actions back to their original form
"""
import torch
from transformers import AutoProcessor
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.processor.tokenizer_processor import ActionTokenizerProcessorStep
from lerobot.utils.constants import ACTION_TOKEN_MASK
# Define delta timestamps for the dataset
delta_timestamps = {
'action': [
0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333,
0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3,
0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335,
0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6,
0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333,
0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9,
0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334,
1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2,
1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333,
1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5,
1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333
]
}
# Load the dataset
print("Loading dataset...")
dataset = LeRobotDataset(
repo_id="local",
root="/fsx/jade_choghari/outputs/pgen_annotations1",
delta_timestamps=delta_timestamps
)
# Create a dataloader
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
# Get a batch of data
batch = next(iter(dataloader))
action_data = batch["action"] # Shape: (batch_size, action_horizon, action_dim)
print(f"\nOriginal action shape: {action_data.shape}")
print(f"Original action data (first sample, first timestep):\n{action_data[0, 0]}")
# Method 1: Using the tokenizer directly (as in fast_tokenize.py)
print("\n" + "="*80)
print("Method 1: Direct tokenizer usage")
print("="*80)
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# Tokenize directly
tokens = tokenizer(action_data)
print(f"\nDirect tokenization result type: {type(tokens)}")
print(f"Tokens shape/length: {tokens.shape if isinstance(tokens, torch.Tensor) else len(tokens)}")
# Decode
decoded_actions = tokenizer.decode(tokens)
print(f"Decoded actions shape: {decoded_actions.shape}")
reconstruction_error = torch.abs(action_data - decoded_actions).mean()
print(f"Mean absolute reconstruction error: {reconstruction_error.item():.6f}")
# Method 2: Using the ActionTokenizerProcessorStep with proper padding/truncation
print("\n" + "="*80)
print("Method 2: Using ActionTokenizerProcessorStep (with padding & mask)")
print("="*80)
# Create the action tokenizer processor step
action_tokenizer_processor = ActionTokenizerProcessorStep(
tokenizer_name="physical-intelligence/fast",
trust_remote_code=True,
max_action_tokens=32, # Maximum number of tokens per action
)
# Create a transition with the action data
transition = {
TransitionKey.ACTION: action_data,
TransitionKey.OBSERVATION: {}, # Empty for this example
}
# Apply the processor
processed_transition = action_tokenizer_processor(transition)
# Extract tokenized actions and mask
tokenized_actions = processed_transition[TransitionKey.ACTION]
complementary_data = processed_transition[TransitionKey.COMPLEMENTARY_DATA]
action_mask = complementary_data[ACTION_TOKEN_MASK]
print(f"\nTokenized actions shape: {tokenized_actions.shape}") # (batch_size, max_action_tokens)
print(f"Action mask shape: {action_mask.shape}") # (batch_size, max_action_tokens)
print(f"Tokenized actions dtype: {tokenized_actions.dtype}")
print(f"Action mask dtype: {action_mask.dtype}")
# Show token statistics
print(f"\nFirst sample tokens: {tokenized_actions[0]}")
print(f"First sample mask: {action_mask[0]}")
num_real_tokens = action_mask[0].sum().item()
print(f"Number of real tokens (non-padding): {num_real_tokens}")
print(f"Number of padding tokens: {action_mask.shape[1] - num_real_tokens}")
# Decode using the mask
print("\nDecoding tokenized actions...")
decoded_with_processor = tokenizer.decode(tokenized_actions)
print(f"Decoded actions shape: {decoded_with_processor.shape}")
# Calculate reconstruction error
reconstruction_error_processor = torch.abs(action_data - decoded_with_processor).mean()
print(f"Mean absolute reconstruction error: {reconstruction_error_processor.item():.6f}")
# Show that masking works correctly
print("\n" + "="*80)
print("Mask demonstration")
print("="*80)
for i in range(min(4, tokenized_actions.shape[0])):
mask_i = action_mask[i]
num_real = mask_i.sum().item()
print(f"Sample {i}: {num_real} real tokens, {len(mask_i) - num_real} padding tokens")
print("\n" + "="*80)
print("Action tokenization example completed successfully!")
print("="*80)
+25
View File
@@ -0,0 +1,25 @@
import numpy as np
from transformers import AutoProcessor
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
batch = next(iter(dataloader))
# Load the tokenizer from the Hugging Face hub
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# Tokenize & decode action chunks (we use dummy data here)
action_data = batch["action"] # one batch of action chunks
tokens = tokenizer(action_data) # tokens = list[int]
decoded_actions = tokenizer.decode(tokens)
print("tokenized actions: ", tokens)
+3 -4
View File
@@ -10,13 +10,13 @@ from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.configs.policies import PreTrainedConfig
cfg = PreTrainedConfig.from_pretrained(
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
)
cfg.dtype = "bfloat16"
pre_processor, post_processor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
pretrained_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
)
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
@@ -45,13 +45,12 @@ dataloader = torch.utils.data.DataLoader(
)
batch = next(iter(dataloader))
batch = pre_processor(batch)
breakpoint()
policy.train()
# run inference
# action = policy.select_action(batch)
loss, loss_dict = policy.forward(batch)
breakpoint()
# import requests
# from PIL import Image
# from transformers import AutoProcessor
+159
View File
@@ -0,0 +1,159 @@
## One-sentence answer
> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**.
Everything else youve seen so far was just metadata.
---
## What goes in
### Inputs
```python
prefix_pad_masks # shape [B, L]
prefix_att_masks # shape [B, L]
```
Where:
* `prefix_pad_masks[b, i] = True`
→ token `i` exists (not padding)
* `prefix_att_masks[b, i] = False`
→ token `i` is **bidirectional**
* `prefix_att_masks[b, i] = True`
→ token `i` is **causal (autoregressive)**
---
## What comes out
```python
att_2d_prefix # shape [B, L, L]
```
Each entry:
```text
att_2d_prefix[b, i, j] = True
```
means:
> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.”
---
## How it is constructed (conceptually)
For **each batch b**, **each query position i**, **each key position j**:
```python
if not prefix_pad_masks[b, j]:
att[b, i, j] = False # cannot attend to padding
else if not prefix_att_masks[b, i]:
att[b, i, j] = True # bidirectional token → can see all real tokens
else:
att[b, i, j] = (j <= i) # causal token → can see only past + itself
```
Thats it.
---
## Tiny concrete example (exactly matching your code)
Suppose:
```python
prefix_pad_masks[0] = [T, T, T, T, T, F]
prefix_att_masks[0] = [F, F, F, T, T, T]
```
Tokens:
```
0: IMG
1: IMG
2: LANG
3: SUB0
4: SUB1
5: PAD
```
---
### Resulting `att_2d_prefix[0]`
`✓ = True, ✗ = False`
| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 |
| ---------- | - | - | - | - | - | - |
| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
---
## Why this matters for your training code
This line:
```python
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
```
Converts `[B, L, L] → [B, 1, L, L]` and possibly flips True/False to `0/-inf`.
This is **exactly what Paligemma uses inside self-attention**.
---
## Key implications (VERY important)
### 1️⃣ This mask does **not isolate token groups**
* Bidirectional tokens can attend to **everything**
* Causal tokens only restrict *their own row*
So **flow/action tokens must be blocked separately**.
---
### 2️⃣ This is why your AR subtask prediction works
* Subtask tokens are causal
* Output at position `i` predicts token `i+1`
* Padding is fully ignored
---
### 3️⃣ Inference behavior
When `subtask_tokens = None`:
* `prefix_att_masks` contains only `False`
* `att_2d_prefix` becomes **fully bidirectional**
* No AR behavior remains
Exactly what you want.
---
## One-sentence takeaway (commit this)
> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.
If you want next, I can:
* inspect `make_att_2d_masks()` source with you
* show how to block **flow → subtask** attention
* explain how this changes when suffix tokens are added
* help you refactor this into a cleaner “grouped attention” API
Youre now at the point where the models behavior should feel *predictable*, not magical.
+1
View File
@@ -0,0 +1 @@
srun --time 12:00:00 --qos=high --gres=gpu:1 --mem=24G --partition=hopper-prod --container-image /fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh --container-mounts /fsx/jade_choghari
-64
View File
@@ -1,64 +0,0 @@
Fine tune output
(Pdb) images[2].mean()
tensor(-1., device='cuda:0')
(Pdb) images[1].mean()
tensor(-0.5780, device='cuda:0')
(Pdb) images[0].mean()
tensor(-0.7716, device='cuda:0')
(Pdb) (Pdb) high_level_task[0]
tensor([ 2, 7978, 2403, 6911, 235292, 5651, 3124, 573, 18571,
7762, 6643, 573, 9010, 72993, 21810, 4894, 3040, 235292,
235248, 235274, 235274, 235274, 728, 235274, 235248, 235284, 235308,
235308, 235248, 235274, 235318, 235315, 235248, 235274, 235310, 235318,
235248, 235284, 235318, 235248, 235274, 235284, 235321, 235248, 235274,
235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284,
235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321,
235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248,
235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274,
235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284,
235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321,
235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248,
235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274,
235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284,
235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321,
235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235289,
4284, 8277, 235292, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0], device='cuda:0')
(Pdb) subtask_tokens[0]
tensor([ 2, 28040, 7762, 14574, 6643, 9010, 37901, 21810, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
device='cuda:0')
(Pdb) actions.shape
torch.Size([4, 50, 32])
(Pdb) actions.mean()
tensor(0.0143, device='cuda:0')
(Pdb)
Inference:
@@ -37,6 +37,9 @@ class PI05Config(PreTrainedConfig):
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
max_action_tokens: int = 32
fast_vocab_size: int = 2048
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10
+476 -102
View File
@@ -537,6 +537,18 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
# FAST action token embedding and prediction head
self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
# Apply dtype conversion to FAST layers to match model precision
if config.dtype == "bfloat16":
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
elif config.dtype == "float32":
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
@@ -592,6 +604,194 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
result = result.to(dtype=dtype)
return result
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
"""Create custom 2D attention mask for the new attention pattern.
Attention rules:
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
- FAST: attend to images + language + subtask, causal among themselves
Args:
att_mask_segments: List of (type, length) tuples
pad_masks: Padding masks [B, total_seq_len]
bsize: Batch size
Returns:
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
"""
total_len = sum(length for _, length in att_mask_segments)
device = pad_masks.device
# Initialize attention mask as False (cannot attend)
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
# Track positions for each segment
positions = []
current_pos = 0
for seg_type, seg_len in att_mask_segments:
positions.append((seg_type, current_pos, current_pos + seg_len))
current_pos += seg_len
# Apply attention rules
for i, (query_type, query_start, query_end) in enumerate(positions):
for j, (key_type, key_start, key_end) in enumerate(positions):
# Images and Language can attend to each other bidirectionally
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# Subtask tokens attend to images + language
elif query_type == 'subtask' and key_type in ['image', 'language']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# Subtask tokens attend causally to themselves
elif query_type == 'subtask' and key_type == 'subtask':
# Create causal mask for subtask tokens
subtask_len = query_end - query_start
causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device))
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
# FAST tokens attend to images + language + subtask
elif query_type == 'fast' and key_type in ['image', 'language', 'subtask']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# FAST tokens attend causally to themselves
elif query_type == 'fast' and key_type == 'fast':
fast_len = query_end - query_start
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
# Apply padding masks
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def visualize_attention_mask(
self,
att_mask_segments,
att_2d_masks,
save_path,
batch_idx=0,
dpi=150,
max_display_tokens=None
):
"""Visualize the attention mask with labeled segments.
Args:
att_mask_segments: List of (type, length) tuples defining the segments
att_2d_masks: 2D attention mask tensor [B, total_seq_len, total_seq_len]
save_path: Path where to save the visualization image
batch_idx: Which batch item to visualize (default: 0)
dpi: DPI for the saved image (default: 150)
max_display_tokens: Maximum number of tokens to display (for very long sequences)
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
except ImportError:
logging.warning("matplotlib not available, skipping attention mask visualization")
return
# Extract the mask for the specified batch
mask = att_2d_masks[batch_idx].cpu().float().numpy()
# If sequence is too long, downsample for visualization
if max_display_tokens is not None and mask.shape[0] > max_display_tokens:
# Simple downsampling by taking every Nth token
step = mask.shape[0] // max_display_tokens
mask = mask[::step, ::step]
# Adjust segments accordingly
att_mask_segments = [(seg_type, max(1, seg_len // step)) for seg_type, seg_len in att_mask_segments]
# Calculate positions for each segment
positions = []
current_pos = 0
for seg_type, seg_len in att_mask_segments:
positions.append((seg_type, current_pos, current_pos + seg_len))
current_pos += seg_len
# Create figure
fig, ax = plt.subplots(figsize=(12, 10))
# Create custom colormap: white for False (no attention), blue for True (attention)
colors = ['white', '#2E86AB']
n_bins = 2
cmap = LinearSegmentedColormap.from_list('attention', colors, N=n_bins)
# Display the mask
im = ax.imshow(mask, cmap=cmap, aspect='auto', interpolation='nearest', vmin=0, vmax=1)
# Add colorbar
cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Attention Enabled', rotation=270, labelpad=20)
cbar.set_ticks([0.25, 0.75])
cbar.set_ticklabels(['No', 'Yes'])
# Define colors for each segment type
segment_colors = {
'image': '#A23B72',
'language': '#F18F01',
'subtask': '#C73E1D',
'fast': '#6A994E'
}
# Draw segment boundaries and labels
for seg_type, start, end in positions:
color = segment_colors.get(seg_type, '#666666')
# Draw vertical lines for columns (keys)
ax.axvline(x=start - 0.5, color=color, linewidth=2, alpha=0.7)
ax.axvline(x=end - 0.5, color=color, linewidth=2, alpha=0.7)
# Draw horizontal lines for rows (queries)
ax.axhline(y=start - 0.5, color=color, linewidth=2, alpha=0.7)
ax.axhline(y=end - 0.5, color=color, linewidth=2, alpha=0.7)
# Add labels at the top
mid_pos = (start + end) / 2
ax.text(mid_pos, -mask.shape[0] * 0.02, f"{seg_type.upper()}\n({end - start})",
ha='center', va='top', fontsize=10, fontweight='bold', color=color)
# Add labels on the left
ax.text(-mask.shape[1] * 0.02, mid_pos, f"{seg_type.upper()}\n({end - start})",
ha='right', va='center', fontsize=10, fontweight='bold', color=color, rotation=0)
# Set axis labels
ax.set_xlabel('Key Position (tokens being attended to)', fontsize=12, fontweight='bold')
ax.set_ylabel('Query Position (tokens attending)', fontsize=12, fontweight='bold')
ax.set_title('Attention Mask Pattern\n(White = No Attention, Blue = Attention Allowed)',
fontsize=14, fontweight='bold', pad=20)
# Create legend for segment types
legend_patches = []
attention_rules = {
'image': 'Bidirectional with lang',
'language': 'Bidirectional with images',
'subtask': 'Attends to img+lang, causal self',
'fast': 'Attends to all, causal self'
}
for seg_type, color in segment_colors.items():
if any(seg[0] == seg_type for seg in att_mask_segments):
rule = attention_rules.get(seg_type, '')
legend_patches.append(mpatches.Patch(color=color, label=f'{seg_type.upper()}: {rule}'))
ax.legend(handles=legend_patches, loc='upper right', bbox_to_anchor=(1.15, 1.0),
framealpha=0.9, fontsize=9)
# Adjust layout and save
plt.tight_layout()
# Ensure the directory exists
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
plt.close()
logging.info(f"Attention mask visualization saved to: {save_path}")
def sample_noise(self, shape, device):
return torch.normal(
mean=0.0,
@@ -609,8 +809,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, tokens, subtask_tokens, masks, subtask_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
self,
images,
img_masks,
tokens,
subtask_tokens,
masks,
subtask_masks,
fast_action_tokens=None,
fast_action_masks=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
Args:
@@ -619,17 +827,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
tokens: Language instruction tokens
subtask_tokens: Subtask tokens to predict (can be None for inference)
masks: Attention masks for tokens
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
fast_action_masks: Padding masks for FAST action tokens (can be None)
Returns:
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided)]
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
pad_masks: Padding masks
att_masks: Attention masks (with causal masking for subtask prediction if subtask_tokens provided)
att_masks: Custom 2D attention mask implementing the required pattern
total_T_images: Total number of image tokens
num_subtask_embs: Number of subtask token embeddings
num_fast_embs: Number of FAST action token embeddings
"""
embs = []
pad_masks = []
att_masks = []
att_mask_segments = [] # Store info about each segment for custom mask creation
total_T_images = 0
num_subtask_embs = 0
num_fast_embs = 0
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
@@ -642,7 +856,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
att_mask_segments.append(('image', num_img_embs))
total_T_images += num_img_embs
# Process language instruction tokens
@@ -656,7 +870,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
pad_masks.append(masks)
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs # Language tokens can attend to all previous tokens (images + tokens)
att_mask_segments.append(('language', num_lang_embs))
# Process subtask tokens if provided (these are predicted, so use causal masking)
if subtask_tokens is not None:
@@ -672,18 +886,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
pad_masks.append(subtask_masks)
num_subtask_embs = subtask_emb.shape[1]
# Causal masking for subtask tokens: each subtask token can attend to images, all instruction tokens,
# and previous subtask tokens
att_masks += [1] * num_subtask_embs # Use 1 for causal attention on subtask tokens
att_mask_segments.append(('subtask', num_subtask_embs))
# Process FAST action tokens if provided (these are discrete token IDs)
if fast_action_tokens is not None:
def fast_action_embed_func(fast_action_tokens):
fast_emb = self.fast_action_embedding(fast_action_tokens)
fast_emb_dim = fast_emb.shape[-1]
return fast_emb * math.sqrt(fast_emb_dim)
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
embs.append(fast_action_emb)
# Use provided mask or create default (all valid)
if fast_action_masks is not None:
fast_pad_mask = fast_action_masks
else:
bsize = fast_action_tokens.shape[0]
num_fast_embs = fast_action_tokens.shape[1]
fast_pad_mask = torch.ones(bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device)
num_fast_embs = fast_action_tokens.shape[1]
pad_masks.append(fast_pad_mask)
att_mask_segments.append(('fast', num_fast_embs))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0])
# Create custom 2D attention mask
# Attention rules:
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
# - FAST: attend to images + language + subtask, causal among themselves
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
return embs, pad_masks, att_masks, total_T_images
# # Optionally visualize the attention mask
# self.visualize_attention_mask(
# att_mask_segments=att_mask_segments,
# att_2d_masks=att_masks,
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
# batch_idx=0,
# max_display_tokens=512 # Limit display for very long sequences
# )
return embs, pad_masks, att_masks, total_T_images, num_subtask_embs, num_fast_embs
def embed_suffix(self, noisy_actions, timestep):
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
@@ -732,8 +977,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return embs, pad_masks, att_masks, adarms_cond
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions)
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, noise=None, time=None) -> Tensor:
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions, fast_action_tokens, fast_action_masks)
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, fast_action_tokens=None, fast_action_masks=None, noise=None, time=None) -> Tensor:
"""Do a full training forward pass and compute the loss.
Args:
@@ -743,7 +988,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
high_level_task_masks: Attention masks for high_level_task
subtask_tokens: Subtask tokens to predict (e.g., tokens for "pick up the cup")
subtask_masks: Attention masks for subtask_tokens
actions: Ground truth actions
actions: Ground truth actions [B, chunk_size, action_dim]
fast_action_tokens: Discrete action token IDs [B, max_action_tokens]
fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens]
noise: Optional noise for flow matching
time: Optional time for flow matching
"""
@@ -757,39 +1004,45 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
# Embed prefix (images + high_level_task + subtask_tokens)
# Use high_level_task (prompt WITHOUT subtask) + subtask_tokens to predict
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks
# Initialize FAST loss to 0 (will be computed only if FAST tokens are provided)
fast_loss = torch.tensor(0.0, device=actions.device, dtype=actions.dtype)
# ========== PASS 1: Prefix with FAST tokens for subtask + FAST prediction ==========
# Only run this pass if FAST action tokens are provided
if fast_action_tokens is not None and fast_action_masks is not None:
# Embed prefix (images + high_level_task + subtask_tokens + FAST tokens)
# FAST tokens are provided as discrete token IDs
prefix_with_fast_embs, prefix_with_fast_pad_masks, prefix_with_fast_att_masks, total_T_images, num_subtask_embs, num_fast_embs = self.embed_prefix(
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
# Convert embeddings to bfloat16 if needed for the model
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_with_fast_embs = prefix_with_fast_embs.to(dtype=torch.bfloat16)
# Prepare attention masks for prefix-only pass (for subtask token prediction)
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
# Prepare attention masks for prefix pass with FAST tokens
position_ids_prefix_with_fast = torch.cumsum(prefix_with_fast_pad_masks, dim=1) - 1
att_2d_prefix_with_fast_4d = self._prepare_attention_masks_4d(prefix_with_fast_att_masks, dtype=prefix_with_fast_embs.dtype)
# prefix-only transformer run for subtask token prediction
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_4d,
position_ids=position_ids_prefix,
# Forward pass through paligemma for subtask + FAST prediction
(prefix_with_fast_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_with_fast_4d,
position_ids=position_ids_prefix_with_fast,
past_key_values=None,
inputs_embeds=[prefix_embs, None], # SUFFIX = None
inputs_embeds=[prefix_with_fast_embs, None], # SUFFIX = None
use_cache=False,
adarms_cond=[None, None],
)
# LM HEAD → SUBTASK LOGITS
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_high_level_task + T_subtask
lm_head = self.paligemma_with_expert.paligemma.lm_head
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
logits = lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, vocab)
# Extract logits for subtask token prediction
# In autoregressive modeling, output at position i predicts token at position i+1
# So we take logits from one position earlier:
# - Position (start_index-1) (last high_level_task token) predicts subtask_tokens[0]
# - Position (start_index) (first subtask token) predicts subtask_tokens[1], etc.
T_high_level_task = high_level_task.size(1)
T_subtask = subtask_tokens.size(1)
start_index = total_T_images + T_high_level_task
@@ -797,35 +1050,137 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab)
targets = subtask_tokens # (B, T_subtask)
# Compute cross-entropy loss
# Compute cross-entropy loss for subtask
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
# Reshape for loss computation
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) # (B*T_subtask, vocab)
targets_flat = targets.reshape(-1) # (B*T_subtask)
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_subtask)
loss_per_token = loss_per_token.reshape(targets.shape) # (B, T_subtask)
# Apply mask and compute mean loss over valid tokens
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
targets_flat = targets.reshape(-1)
loss_per_token = loss_fct(logits_flat, targets_flat)
loss_per_token = loss_per_token.reshape(targets.shape)
masked_loss = loss_per_token * subtask_masks.float()
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
# Extract outputs for FAST action token prediction and compute auxiliary loss
# FAST outputs start after subtask tokens
# Similar to subtask, we use autoregressive prediction where position i predicts token i+1
fast_start_index = end_index
fast_end_index = fast_start_index + num_fast_embs
# Get logits for FAST action tokens using the FAST LM head
fast_logits = self.fast_action_lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, fast_vocab_size)
# Extract logits for FAST token prediction (autoregressive: position i predicts token i+1)
# - Position (fast_start_index-1) predicts fast_action_tokens[0]
# - Position (fast_start_index) predicts fast_action_tokens[1], etc.
fast_logits_for_pred = fast_logits[:, fast_start_index-1:fast_end_index-1, :] # (B, max_action_tokens, fast_vocab_size)
# Compute cross-entropy loss for FAST action tokens
fast_targets = fast_action_tokens # (B, max_action_tokens)
loss_fct_fast = torch.nn.CrossEntropyLoss(reduction='none')
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) # (B*max_action_tokens, fast_vocab_size)
fast_targets_flat = fast_targets.reshape(-1) # (B*max_action_tokens)
fast_loss_per_token = loss_fct_fast(fast_logits_flat, fast_targets_flat) # (B*max_action_tokens)
fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape) # (B, max_action_tokens)
# Apply mask and compute mean loss over valid tokens
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
else:
# If no FAST tokens provided, compute subtask loss without FAST tokens
# This is the fallback for backward compatibility
prefix_embs_for_subtask, prefix_pad_masks_for_subtask, prefix_att_masks_for_subtask, total_T_images, _, _ = self.embed_prefix(
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
fast_action_tokens=None, fast_action_masks=None
)
# Convert embeddings to bfloat16 if needed for the model
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs_for_subtask = prefix_embs_for_subtask.to(dtype=torch.bfloat16)
position_ids_prefix = torch.cumsum(prefix_pad_masks_for_subtask, dim=1) - 1
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks_for_subtask, dtype=prefix_embs_for_subtask.dtype)
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_4d,
position_ids=position_ids_prefix,
past_key_values=None,
inputs_embeds=[prefix_embs_for_subtask, None],
use_cache=False,
adarms_cond=[None, None],
)
lm_head = self.paligemma_with_expert.paligemma.lm_head
logits = lm_head(prefix_out)
T_high_level_task = high_level_task.size(1)
T_subtask = subtask_tokens.size(1)
start_index = total_T_images + T_high_level_task
end_index = start_index + T_subtask
logits_subtask = logits[:, start_index-1:end_index-1, :]
targets = subtask_tokens
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
targets_flat = targets.reshape(-1)
loss_per_token = loss_fct(logits_flat, targets_flat)
loss_per_token = loss_per_token.reshape(targets.shape)
masked_loss = loss_per_token * subtask_masks.float()
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
# ========== PASS 2: Full forward WITHOUT FAST tokens for flow matching ==========
# Embed prefix WITHOUT FAST tokens (images + high_level_task + subtask_tokens)
prefix_embs_no_fast, prefix_pad_masks_no_fast, prefix_att_masks_no_fast, _, _, _ = self.embed_prefix(
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
fast_action_tokens=None, fast_action_masks=None
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
# Convert embeddings to bfloat16 if needed for the model
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
prefix_embs_no_fast = prefix_embs_no_fast.to(dtype=torch.bfloat16)
# Concatenate prefix (images + tokens + subtask_tokens) and suffix (actions) masks
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
# For the flow matching pass, we need custom attention where:
# - prefix follows the custom pattern (images+lang bidirectional, subtask causal, no cross-attention)
# - suffix attends to all prefix + causal to itself
# We'll construct this by extending prefix_att_masks_no_fast to include suffix
# prefix_att_masks_no_fast is already a 2D boolean mask [B, prefix_len, prefix_len]
# We need to extend it to [B, prefix_len + suffix_len, prefix_len + suffix_len]
bsize = prefix_pad_masks_no_fast.shape[0]
prefix_len = prefix_pad_masks_no_fast.shape[1]
suffix_len = suffix_pad_masks.shape[1]
total_len = prefix_len + suffix_len
device = prefix_pad_masks_no_fast.device
# Create full attention mask
full_att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
# Copy prefix attention pattern
full_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_masks_no_fast
# Suffix attends to all prefix
full_att_2d_masks[:, prefix_len:, :prefix_len] = True
# Suffix has causal attention among itself
suffix_causal_mask = torch.tril(torch.ones(suffix_len, suffix_len, dtype=torch.bool, device=device))
full_att_2d_masks[:, prefix_len:, prefix_len:] = suffix_causal_mask[None, :, :]
# Apply padding masks
pad_masks = torch.cat([prefix_pad_masks_no_fast, suffix_pad_masks], dim=1)
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
full_att_2d_masks = full_att_2d_masks & pad_2d_masks
# Prepare attention masks for full forward pass (prefix + suffix)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=prefix_embs_no_fast.dtype)
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
@@ -836,11 +1191,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
use_cache=False,
adarms_cond=[None, adarms_cond],
)
# prefix_out to be used for the language head
return suffix_out
suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
forward_func, prefix_embs_no_fast, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
suffix_out = suffix_out[:, -self.config.chunk_size :]
@@ -856,78 +1210,80 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return {
"flow_loss": fm_loss,
"subtask_loss": subtask_loss,
"loss": 10 * fm_loss.mean() + subtask_loss,
"fast_loss": fast_loss,
"loss": fm_loss.mean() + 0.1 * subtask_loss + 0.05 * fast_loss, # ref: b1k winner
}
@torch.no_grad()
def _generate_subtask_tokens(
self, images, img_masks, tokens, masks, tokenizer, max_length, device
):
"""Generate subtask tokens autoregressively using next token prediction."""
bsize = tokens.shape[0]
# Get lm_head for token generation
lm_head = self.paligemma_with_expert.paligemma.lm_head
# Embed prefix without subtask tokens first
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _, _ = self.embed_prefix(
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None,
fast_action_tokens=None, fast_action_masks=None
)
# Initialize generated tokens list - start with BOS token or first token after instruction
# For PaliGemma, we'll start generation and accumulate tokens
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
for t in range(max_length):
# Prepare attention masks for current prefix
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
# tracking mask: False = still generating, True = finished
finished = torch.zeros(bsize, dtype=torch.bool, device=device)
for t in range(max_length):
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
# Forward pass through model to get logits
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_4d,
position_ids=position_ids_prefix,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
adarms_cond=[None, None],
# ...
)
# Get logits from the last position
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
next_token_logits = logits[:, -1, :] # (B, vocab)
# Greedy decoding - take the most likely token
logits = lm_head(prefix_out)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
# Store generated token
# 1. if a row was already finished, force the next token to be PAD (0)
next_token = torch.where(finished, torch.tensor(0, device=device), next_token)
# 2. store the token
generated_tokens[:, t] = next_token
# Check for EOS token - if all batches have generated EOS, stop
# 3. update the finished mask
if tokenizer.eos_token_id is not None:
if (next_token == tokenizer.eos_token_id).all():
finished |= (next_token == tokenizer.eos_token_id)
# 4. break only if everyone is finished
if finished.all():
break
# Embed the generated token and append to prefix
next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1)
next_token_unsqueezed = next_token.unsqueeze(1)
def next_token_embed_func(next_token_unsqueezed):
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
next_emb_dim = next_emb.shape[-1]
return next_emb * math.sqrt(next_emb_dim)
return next_emb * math.sqrt(next_emb.shape[-1])
next_emb = self._apply_checkpoint(next_token_embed_func, next_token_unsqueezed)
# Append to prefix embeddings
# update embeddings
prefix_embs = torch.cat([prefix_embs, next_emb], dim=1)
# Update masks - new token is valid and uses causal attention
# update padding masks
prefix_pad_masks = torch.cat([
prefix_pad_masks,
torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1)
prefix_att_masks = torch.cat([prefix_att_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
# update attention masks
old_seq_len = prefix_att_masks.shape[1]
new_seq_len = old_seq_len + 1
new_att_masks = torch.zeros((bsize, new_seq_len, new_seq_len), dtype=torch.bool, device=device)
new_att_masks[:, :old_seq_len, :old_seq_len] = prefix_att_masks
new_att_masks[:, -1, :] = prefix_pad_masks
prefix_att_masks = new_att_masks
return generated_tokens
@@ -978,13 +1334,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
subtask_masks = torch.ones_like(generated_subtask_tokens, dtype=torch.bool)
# During inference, we don't have subtask_tokens yet, so pass None
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks
# Also no FAST tokens during inference
prefix_embs, prefix_pad_masks, prefix_att_masks, _, _, _ = self.embed_prefix(
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks,
fast_action_tokens=None, fast_action_masks=None
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
# Convert embeddings to bfloat16 if needed for the model
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
# prefix_att_masks is already a 2D attention mask from embed_prefix
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype)
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
@@ -1209,7 +1575,7 @@ class PI05Policy(PreTrainedPolicy):
print(f"Remapped {remap_count} state dict keys")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)
if missing_keys:
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
@@ -1442,19 +1808,26 @@ class PI05Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
# Decode and print ground truth subtask tokens during training
if self.tokenizer is not None and self.training:
bsize = subtask_tokens.shape[0]
for i in range(bsize):
# Remove padding tokens (0) and special tokens
valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
if len(valid_tokens) > 0:
decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
# print(f"[Training] Ground truth subtask {i}: {decoded_text}")
# if self.tokenizer is not None and self.training:
# bsize = subtask_tokens.shape[0]
# for i in range(bsize):
# # Remove padding tokens (0) and special tokens
# valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
# # if len(valid_tokens) > 0:
# # decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
# # print(f"[Training] Ground truth subtask {i}: {decoded_text}")
# Get FAST action tokens from batch
fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens)
fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens)
# Compute loss (no separate state needed for PI05)
# high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
# subtask_tokens = subtask tokens to predict (e.g., "pick up the cup")
loss_dict = self.model.forward(images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions)
# fast_action_tokens = discrete action token IDs to predict
loss_dict = self.model.forward(
images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions,
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
)
# Extract the total loss
loss = loss_dict["loss"]
@@ -1464,6 +1837,7 @@ class PI05Policy(PreTrainedPolicy):
"loss": loss.item(),
"flow_loss": loss_dict["flow_loss"].mean().item(),
"subtask_loss": loss_dict["subtask_loss"].item(),
"fast_loss": loss_dict["fast_loss"].item(),
}
return loss, detailed_loss_dict
+4 -1
View File
@@ -33,6 +33,7 @@ from lerobot.processor import (
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
ActionTokenizerProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
@@ -158,7 +159,6 @@ def make_pi05_pre_post_processors(
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
@@ -177,6 +177,9 @@ def make_pi05_pre_post_processors(
padding_side="right",
padding="max_length",
),
ActionTokenizerProcessorStep(
tokenizer_name="physical-intelligence/fast",
),
DeviceProcessorStep(device=config.device),
]
+1 -1
View File
@@ -75,7 +75,7 @@ from .policy_robot_bridge import (
RobotActionToPolicyActionProcessorStep,
)
from .rename_processor import RenameObservationsProcessorStep
from .tokenizer_processor import TokenizerProcessorStep
from .tokenizer_processor import TokenizerProcessorStep, ActionTokenizerProcessorStep
__all__ = [
"ActionProcessorStep",
+241 -5
View File
@@ -15,10 +15,13 @@
# limitations under the License.
"""
This script defines a processor for tokenizing natural language instructions from an environment transition.
This script defines processors for tokenizing data from an environment transition.
It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
token IDs and attention masks, which are then added to the observation dictionary.
It includes:
- TokenizerProcessorStep: Uses a tokenizer from the Hugging Face `transformers` library to convert
task descriptions (text) into token IDs and attention masks, which are then added to the observation dictionary.
- ActionTokenizerProcessorStep: Uses a processor/tokenizer (e.g., the Physical Intelligence "fast" tokenizer)
to tokenize action tensors into discrete token IDs for action modeling.
"""
from __future__ import annotations
@@ -30,6 +33,8 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
@@ -40,12 +45,13 @@ from lerobot.utils.constants import (
from lerobot.utils.import_utils import _transformers_available
from .core import EnvTransition, TransitionKey
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers import AutoProcessor, AutoTokenizer
else:
AutoProcessor = None
AutoTokenizer = None
@@ -423,3 +429,233 @@ class TokenizerProcessorStep(ObservationProcessorStep):
)
return features
@dataclass
@ProcessorStepRegistry.register(name="action_tokenizer_processor")
class ActionTokenizerProcessorStep(ActionProcessorStep):
"""
Processor step to tokenize action data using a fast action tokenizer.
This step takes action tensors from an `EnvTransition`, tokenizes them using
a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
and returns the tokenized action.
Requires the `transformers` library to be installed.
Attributes:
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None
trust_remote_code: bool = True
max_action_tokens: int = 32
# Internal tokenizer instance (not part of the config)
action_tokenizer: Any = field(default=None, init=False, repr=False)
def __post_init__(self):
"""
Initializes the action tokenizer after the dataclass is created.
It checks for the availability of the `transformers` library and loads the tokenizer
either from a provided object or by name from the Hugging Face Hub.
Raises:
ImportError: If the `transformers` library is not installed.
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
"""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self.action_tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if AutoProcessor is None:
raise ImportError("AutoProcessor is not available")
self.action_tokenizer = AutoProcessor.from_pretrained(
self.tokenizer_name, trust_remote_code=self.trust_remote_code
)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Applies action tokenization to the transition.
This overrides the base class to handle both tokens and mask.
Args:
transition: The input transition with action data.
Returns:
The processed transition with tokenized actions and mask in complementary data.
"""
self._current_transition = transition.copy()
new_transition = self._current_transition
action = new_transition.get(TransitionKey.ACTION)
if action is None:
raise ValueError("ActionTokenizerProcessorStep requires an action in the transition.")
# Tokenize and get both tokens and mask
tokens, mask = self._tokenize_action(action)
# Store mask in complementary data
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if complementary_data is None:
complementary_data = {}
complementary_data[ACTION_TOKEN_MASK] = mask
complementary_data[ACTION_TOKENS] = tokens
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Tokenizes the action tensor and creates a mask.
Args:
action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
Returns:
A tuple of (tokens, mask) where:
- tokens: Tensor of token IDs with shape (B, max_action_tokens)
- mask: Boolean mask with shape (B, max_action_tokens), True for real tokens, False for padding
"""
if action is None:
raise ValueError("Action cannot be None")
# Get the device and dtype of the input action
device = action.device if isinstance(action, torch.Tensor) else None
# Handle single sample (add batch dimension)
single_sample = action.dim() == 1
if single_sample:
action = action.unsqueeze(0)
batch_size = action.shape[0]
# Tokenize the action batch
# The fast tokenizer expects action data and returns token IDs
tokens_list = []
masks_list = []
for i in range(batch_size):
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
action_cpu = action[i:i+1].cpu()
tokens = self.action_tokenizer(action_cpu)
# Convert to numpy array if it's a list
if isinstance(tokens, list):
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
elif not isinstance(tokens, torch.Tensor):
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
else:
# Move tokens back to the same device as input action
tokens = tokens.to(device=action.device)
# Flatten to 1D if needed
if tokens.dim() > 1:
tokens = tokens.flatten()
# Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens:
tokens = tokens[:self.max_action_tokens]
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
else:
mask = torch.cat([
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
torch.zeros(self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device)
])
# Pad tokens with zeros
tokens = torch.nn.functional.pad(
tokens,
(0, self.max_action_tokens - len(tokens)),
value=0
)
tokens_list.append(tokens)
masks_list.append(mask)
# Stack into batched tensors
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
# Remove batch dimension if input was single sample
if single_sample:
tokens_batch = tokens_batch.squeeze(0)
masks_batch = masks_batch.squeeze(0)
# Move to the same device as the input
if device is not None:
tokens_batch = tokens_batch.to(device)
masks_batch = masks_batch.to(device)
return tokens_batch, masks_batch
def action(self, action: torch.Tensor) -> torch.Tensor:
"""
This method is not used since we override __call__.
Required by ActionProcessorStep ABC.
"""
tokens, _ = self._tokenize_action(action)
return tokens
def get_config(self) -> dict[str, Any]:
"""
Returns the serializable configuration of the processor.
Note: The tokenizer object itself is not serialized. If the processor was initialized
with a tokenizer name, that name will be included in the config.
Returns:
A dictionary with the processor's configuration parameters.
"""
config = {
"trust_remote_code": self.trust_remote_code,
"max_action_tokens": self.max_action_tokens,
}
# Only save tokenizer_name if it was used to create the tokenizer
if self.tokenizer_name is not None and self.tokenizer is None:
config["tokenizer_name"] = self.tokenizer_name
return config
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Updates feature definitions to reflect tokenized actions.
This updates the policy features dictionary to indicate that the action
has been tokenized into a sequence of token IDs with shape (max_action_tokens,).
Args:
features: The dictionary of existing policy features.
Returns:
The updated dictionary of policy features.
"""
# Update the action feature to reflect the tokenized shape
# The action is now a sequence of token IDs
if PipelineFeatureType.ACTION in features:
# Replace the action feature with the tokenized version
features[PipelineFeatureType.ACTION] = {
ACTION_TOKENS: PolicyFeature(
type=FeatureType.SEQUENCE, # Token sequence
shape=(self.max_action_tokens,)
)
}
return features
+1
View File
@@ -90,6 +90,7 @@ def update_policy(
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
action = policy.select_action(batch)
breakpoint()
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
+2
View File
@@ -33,6 +33,8 @@ OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask"
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens"
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask"
ACTION = "action"
ACTION_TOKENS = ACTION + ".tokens"
ACTION_TOKEN_MASK = ACTION + ".token_mask"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"