mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
add fast tokenizer support
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Example demonstrating how to use the ActionTokenizerProcessorStep to tokenize actions.
|
||||
|
||||
This example shows how to:
|
||||
1. Load a dataset with action data
|
||||
2. Apply the action tokenizer processor to tokenize actions with proper padding/truncation
|
||||
3. Access both the tokenized actions and the attention mask
|
||||
4. Decode tokenized actions back to their original form
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.processor.tokenizer_processor import ActionTokenizerProcessorStep
|
||||
from lerobot.utils.constants import ACTION_TOKEN_MASK
|
||||
|
||||
# Define delta timestamps for the dataset
|
||||
delta_timestamps = {
|
||||
'action': [
|
||||
0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333,
|
||||
0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3,
|
||||
0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335,
|
||||
0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6,
|
||||
0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333,
|
||||
0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9,
|
||||
0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334,
|
||||
1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2,
|
||||
1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333,
|
||||
1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5,
|
||||
1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333
|
||||
]
|
||||
}
|
||||
|
||||
# Load the dataset
|
||||
print("Loading dataset...")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id="local",
|
||||
root="/fsx/jade_choghari/outputs/pgen_annotations1",
|
||||
delta_timestamps=delta_timestamps
|
||||
)
|
||||
|
||||
# Create a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
# Get a batch of data
|
||||
batch = next(iter(dataloader))
|
||||
action_data = batch["action"] # Shape: (batch_size, action_horizon, action_dim)
|
||||
|
||||
print(f"\nOriginal action shape: {action_data.shape}")
|
||||
print(f"Original action data (first sample, first timestep):\n{action_data[0, 0]}")
|
||||
|
||||
# Method 1: Using the tokenizer directly (as in fast_tokenize.py)
|
||||
print("\n" + "="*80)
|
||||
print("Method 1: Direct tokenizer usage")
|
||||
print("="*80)
|
||||
|
||||
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
|
||||
# Tokenize directly
|
||||
tokens = tokenizer(action_data)
|
||||
print(f"\nDirect tokenization result type: {type(tokens)}")
|
||||
print(f"Tokens shape/length: {tokens.shape if isinstance(tokens, torch.Tensor) else len(tokens)}")
|
||||
|
||||
# Decode
|
||||
decoded_actions = tokenizer.decode(tokens)
|
||||
print(f"Decoded actions shape: {decoded_actions.shape}")
|
||||
reconstruction_error = torch.abs(action_data - decoded_actions).mean()
|
||||
print(f"Mean absolute reconstruction error: {reconstruction_error.item():.6f}")
|
||||
|
||||
# Method 2: Using the ActionTokenizerProcessorStep with proper padding/truncation
|
||||
print("\n" + "="*80)
|
||||
print("Method 2: Using ActionTokenizerProcessorStep (with padding & mask)")
|
||||
print("="*80)
|
||||
|
||||
# Create the action tokenizer processor step
|
||||
action_tokenizer_processor = ActionTokenizerProcessorStep(
|
||||
tokenizer_name="physical-intelligence/fast",
|
||||
trust_remote_code=True,
|
||||
max_action_tokens=32, # Maximum number of tokens per action
|
||||
)
|
||||
|
||||
# Create a transition with the action data
|
||||
transition = {
|
||||
TransitionKey.ACTION: action_data,
|
||||
TransitionKey.OBSERVATION: {}, # Empty for this example
|
||||
}
|
||||
|
||||
# Apply the processor
|
||||
processed_transition = action_tokenizer_processor(transition)
|
||||
|
||||
# Extract tokenized actions and mask
|
||||
tokenized_actions = processed_transition[TransitionKey.ACTION]
|
||||
complementary_data = processed_transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
action_mask = complementary_data[ACTION_TOKEN_MASK]
|
||||
|
||||
print(f"\nTokenized actions shape: {tokenized_actions.shape}") # (batch_size, max_action_tokens)
|
||||
print(f"Action mask shape: {action_mask.shape}") # (batch_size, max_action_tokens)
|
||||
print(f"Tokenized actions dtype: {tokenized_actions.dtype}")
|
||||
print(f"Action mask dtype: {action_mask.dtype}")
|
||||
|
||||
# Show token statistics
|
||||
print(f"\nFirst sample tokens: {tokenized_actions[0]}")
|
||||
print(f"First sample mask: {action_mask[0]}")
|
||||
num_real_tokens = action_mask[0].sum().item()
|
||||
print(f"Number of real tokens (non-padding): {num_real_tokens}")
|
||||
print(f"Number of padding tokens: {action_mask.shape[1] - num_real_tokens}")
|
||||
|
||||
# Decode using the mask
|
||||
print("\nDecoding tokenized actions...")
|
||||
decoded_with_processor = tokenizer.decode(tokenized_actions)
|
||||
print(f"Decoded actions shape: {decoded_with_processor.shape}")
|
||||
|
||||
# Calculate reconstruction error
|
||||
reconstruction_error_processor = torch.abs(action_data - decoded_with_processor).mean()
|
||||
print(f"Mean absolute reconstruction error: {reconstruction_error_processor.item():.6f}")
|
||||
|
||||
# Show that masking works correctly
|
||||
print("\n" + "="*80)
|
||||
print("Mask demonstration")
|
||||
print("="*80)
|
||||
for i in range(min(4, tokenized_actions.shape[0])):
|
||||
mask_i = action_mask[i]
|
||||
num_real = mask_i.sum().item()
|
||||
print(f"Sample {i}: {num_real} real tokens, {len(mask_i) - num_real} padding tokens")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("Action tokenization example completed successfully!")
|
||||
print("="*80)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import numpy as np
|
||||
from transformers import AutoProcessor
|
||||
import torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
# Load the tokenizer from the Hugging Face hub
|
||||
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
|
||||
# Tokenize & decode action chunks (we use dummy data here)
|
||||
action_data = batch["action"] # one batch of action chunks
|
||||
tokens = tokenizer(action_data) # tokens = list[int]
|
||||
decoded_actions = tokenizer.decode(tokens)
|
||||
print("tokenized actions: ", tokens)
|
||||
@@ -10,13 +10,13 @@ from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
||||
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
||||
pretrained_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
|
||||
)
|
||||
|
||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||
@@ -45,13 +45,12 @@ dataloader = torch.utils.data.DataLoader(
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
batch = pre_processor(batch)
|
||||
breakpoint()
|
||||
policy.train()
|
||||
# run inference
|
||||
# action = policy.select_action(batch)
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
breakpoint()
|
||||
# import requests
|
||||
# from PIL import Image
|
||||
# from transformers import AutoProcessor
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
## One-sentence answer
|
||||
|
||||
> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**.
|
||||
|
||||
Everything else you’ve seen so far was just metadata.
|
||||
|
||||
---
|
||||
|
||||
## What goes in
|
||||
|
||||
### Inputs
|
||||
|
||||
```python
|
||||
prefix_pad_masks # shape [B, L]
|
||||
prefix_att_masks # shape [B, L]
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
* `prefix_pad_masks[b, i] = True`
|
||||
→ token `i` exists (not padding)
|
||||
|
||||
* `prefix_att_masks[b, i] = False`
|
||||
→ token `i` is **bidirectional**
|
||||
|
||||
* `prefix_att_masks[b, i] = True`
|
||||
→ token `i` is **causal (autoregressive)**
|
||||
|
||||
---
|
||||
|
||||
## What comes out
|
||||
|
||||
```python
|
||||
att_2d_prefix # shape [B, L, L]
|
||||
```
|
||||
|
||||
Each entry:
|
||||
|
||||
```text
|
||||
att_2d_prefix[b, i, j] = True
|
||||
```
|
||||
|
||||
means:
|
||||
|
||||
> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.”
|
||||
|
||||
---
|
||||
|
||||
## How it is constructed (conceptually)
|
||||
|
||||
For **each batch b**, **each query position i**, **each key position j**:
|
||||
|
||||
```python
|
||||
if not prefix_pad_masks[b, j]:
|
||||
att[b, i, j] = False # cannot attend to padding
|
||||
else if not prefix_att_masks[b, i]:
|
||||
att[b, i, j] = True # bidirectional token → can see all real tokens
|
||||
else:
|
||||
att[b, i, j] = (j <= i) # causal token → can see only past + itself
|
||||
```
|
||||
|
||||
That’s it.
|
||||
|
||||
---
|
||||
|
||||
## Tiny concrete example (exactly matching your code)
|
||||
|
||||
Suppose:
|
||||
|
||||
```python
|
||||
prefix_pad_masks[0] = [T, T, T, T, T, F]
|
||||
prefix_att_masks[0] = [F, F, F, T, T, T]
|
||||
```
|
||||
|
||||
Tokens:
|
||||
|
||||
```
|
||||
0: IMG
|
||||
1: IMG
|
||||
2: LANG
|
||||
3: SUB0
|
||||
4: SUB1
|
||||
5: PAD
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Resulting `att_2d_prefix[0]`
|
||||
|
||||
`✓ = True, ✗ = False`
|
||||
|
||||
| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 |
|
||||
| ---------- | - | - | - | - | - | - |
|
||||
| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
|
||||
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
|
||||
|
||||
---
|
||||
|
||||
## Why this matters for your training code
|
||||
|
||||
This line:
|
||||
|
||||
```python
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
|
||||
```
|
||||
|
||||
Converts `[B, L, L] → [B, 1, L, L]` and possibly flips True/False to `0/-inf`.
|
||||
|
||||
This is **exactly what Paligemma uses inside self-attention**.
|
||||
|
||||
---
|
||||
|
||||
## Key implications (VERY important)
|
||||
|
||||
### 1️⃣ This mask does **not isolate token groups**
|
||||
|
||||
* Bidirectional tokens can attend to **everything**
|
||||
* Causal tokens only restrict *their own row*
|
||||
|
||||
So **flow/action tokens must be blocked separately**.
|
||||
|
||||
---
|
||||
|
||||
### 2️⃣ This is why your AR subtask prediction works
|
||||
|
||||
* Subtask tokens are causal
|
||||
* Output at position `i` predicts token `i+1`
|
||||
* Padding is fully ignored
|
||||
|
||||
---
|
||||
|
||||
### 3️⃣ Inference behavior
|
||||
|
||||
When `subtask_tokens = None`:
|
||||
|
||||
* `prefix_att_masks` contains only `False`
|
||||
* `att_2d_prefix` becomes **fully bidirectional**
|
||||
* No AR behavior remains
|
||||
|
||||
Exactly what you want.
|
||||
|
||||
---
|
||||
|
||||
## One-sentence takeaway (commit this)
|
||||
|
||||
> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.
|
||||
|
||||
If you want next, I can:
|
||||
|
||||
* inspect `make_att_2d_masks()` source with you
|
||||
* show how to block **flow → subtask** attention
|
||||
* explain how this changes when suffix tokens are added
|
||||
* help you refactor this into a cleaner “grouped attention” API
|
||||
|
||||
You’re now at the point where the model’s behavior should feel *predictable*, not magical.
|
||||
@@ -0,0 +1 @@
|
||||
srun --time 12:00:00 --qos=high --gres=gpu:1 --mem=24G --partition=hopper-prod --container-image /fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh --container-mounts /fsx/jade_choghari
|
||||
@@ -1,64 +0,0 @@
|
||||
|
||||
Fine tune output
|
||||
(Pdb) images[2].mean()
|
||||
tensor(-1., device='cuda:0')
|
||||
(Pdb) images[1].mean()
|
||||
tensor(-0.5780, device='cuda:0')
|
||||
(Pdb) images[0].mean()
|
||||
tensor(-0.7716, device='cuda:0')
|
||||
(Pdb) (Pdb) high_level_task[0]
|
||||
tensor([ 2, 7978, 2403, 6911, 235292, 5651, 3124, 573, 18571,
|
||||
7762, 6643, 573, 9010, 72993, 21810, 4894, 3040, 235292,
|
||||
235248, 235274, 235274, 235274, 728, 235274, 235248, 235284, 235308,
|
||||
235308, 235248, 235274, 235318, 235315, 235248, 235274, 235310, 235318,
|
||||
235248, 235284, 235318, 235248, 235274, 235284, 235321, 235248, 235274,
|
||||
235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284,
|
||||
235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321,
|
||||
235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248,
|
||||
235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274,
|
||||
235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284,
|
||||
235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321,
|
||||
235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248,
|
||||
235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274,
|
||||
235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284,
|
||||
235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321,
|
||||
235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235289,
|
||||
4284, 8277, 235292, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0], device='cuda:0')
|
||||
(Pdb) subtask_tokens[0]
|
||||
tensor([ 2, 28040, 7762, 14574, 6643, 9010, 37901, 21810, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
device='cuda:0')
|
||||
(Pdb) actions.shape
|
||||
torch.Size([4, 50, 32])
|
||||
(Pdb) actions.mean()
|
||||
tensor(0.0143, device='cuda:0')
|
||||
(Pdb)
|
||||
|
||||
|
||||
|
||||
|
||||
Inference:
|
||||
@@ -37,6 +37,9 @@ class PI05Config(PreTrainedConfig):
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
max_action_tokens: int = 32
|
||||
fast_vocab_size: int = 2048
|
||||
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10
|
||||
|
||||
@@ -537,6 +537,18 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
|
||||
# FAST action token embedding and prediction head
|
||||
self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
|
||||
self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
|
||||
|
||||
# Apply dtype conversion to FAST layers to match model precision
|
||||
if config.dtype == "bfloat16":
|
||||
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
|
||||
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
|
||||
elif config.dtype == "float32":
|
||||
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
|
||||
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
|
||||
|
||||
# Initialize gradient checkpointing flag
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
@@ -592,6 +604,194 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
result = result.to(dtype=dtype)
|
||||
return result
|
||||
|
||||
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
|
||||
"""Create custom 2D attention mask for the new attention pattern.
|
||||
|
||||
Attention rules:
|
||||
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||
- FAST: attend to images + language + subtask, causal among themselves
|
||||
|
||||
Args:
|
||||
att_mask_segments: List of (type, length) tuples
|
||||
pad_masks: Padding masks [B, total_seq_len]
|
||||
bsize: Batch size
|
||||
|
||||
Returns:
|
||||
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
|
||||
"""
|
||||
total_len = sum(length for _, length in att_mask_segments)
|
||||
device = pad_masks.device
|
||||
|
||||
# Initialize attention mask as False (cannot attend)
|
||||
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
||||
|
||||
# Track positions for each segment
|
||||
positions = []
|
||||
current_pos = 0
|
||||
for seg_type, seg_len in att_mask_segments:
|
||||
positions.append((seg_type, current_pos, current_pos + seg_len))
|
||||
current_pos += seg_len
|
||||
|
||||
# Apply attention rules
|
||||
for i, (query_type, query_start, query_end) in enumerate(positions):
|
||||
for j, (key_type, key_start, key_end) in enumerate(positions):
|
||||
# Images and Language can attend to each other bidirectionally
|
||||
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# Subtask tokens attend to images + language
|
||||
elif query_type == 'subtask' and key_type in ['image', 'language']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# Subtask tokens attend causally to themselves
|
||||
elif query_type == 'subtask' and key_type == 'subtask':
|
||||
# Create causal mask for subtask tokens
|
||||
subtask_len = query_end - query_start
|
||||
causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device))
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||
|
||||
# FAST tokens attend to images + language + subtask
|
||||
elif query_type == 'fast' and key_type in ['image', 'language', 'subtask']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# FAST tokens attend causally to themselves
|
||||
elif query_type == 'fast' and key_type == 'fast':
|
||||
fast_len = query_end - query_start
|
||||
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||
|
||||
# Apply padding masks
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
|
||||
return att_2d_masks
|
||||
|
||||
def visualize_attention_mask(
|
||||
self,
|
||||
att_mask_segments,
|
||||
att_2d_masks,
|
||||
save_path,
|
||||
batch_idx=0,
|
||||
dpi=150,
|
||||
max_display_tokens=None
|
||||
):
|
||||
"""Visualize the attention mask with labeled segments.
|
||||
|
||||
Args:
|
||||
att_mask_segments: List of (type, length) tuples defining the segments
|
||||
att_2d_masks: 2D attention mask tensor [B, total_seq_len, total_seq_len]
|
||||
save_path: Path where to save the visualization image
|
||||
batch_idx: Which batch item to visualize (default: 0)
|
||||
dpi: DPI for the saved image (default: 150)
|
||||
max_display_tokens: Maximum number of tokens to display (for very long sequences)
|
||||
"""
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
except ImportError:
|
||||
logging.warning("matplotlib not available, skipping attention mask visualization")
|
||||
return
|
||||
|
||||
# Extract the mask for the specified batch
|
||||
mask = att_2d_masks[batch_idx].cpu().float().numpy()
|
||||
|
||||
# If sequence is too long, downsample for visualization
|
||||
if max_display_tokens is not None and mask.shape[0] > max_display_tokens:
|
||||
# Simple downsampling by taking every Nth token
|
||||
step = mask.shape[0] // max_display_tokens
|
||||
mask = mask[::step, ::step]
|
||||
# Adjust segments accordingly
|
||||
att_mask_segments = [(seg_type, max(1, seg_len // step)) for seg_type, seg_len in att_mask_segments]
|
||||
|
||||
# Calculate positions for each segment
|
||||
positions = []
|
||||
current_pos = 0
|
||||
for seg_type, seg_len in att_mask_segments:
|
||||
positions.append((seg_type, current_pos, current_pos + seg_len))
|
||||
current_pos += seg_len
|
||||
|
||||
# Create figure
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
|
||||
# Create custom colormap: white for False (no attention), blue for True (attention)
|
||||
colors = ['white', '#2E86AB']
|
||||
n_bins = 2
|
||||
cmap = LinearSegmentedColormap.from_list('attention', colors, N=n_bins)
|
||||
|
||||
# Display the mask
|
||||
im = ax.imshow(mask, cmap=cmap, aspect='auto', interpolation='nearest', vmin=0, vmax=1)
|
||||
|
||||
# Add colorbar
|
||||
cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
||||
cbar.set_label('Attention Enabled', rotation=270, labelpad=20)
|
||||
cbar.set_ticks([0.25, 0.75])
|
||||
cbar.set_ticklabels(['No', 'Yes'])
|
||||
|
||||
# Define colors for each segment type
|
||||
segment_colors = {
|
||||
'image': '#A23B72',
|
||||
'language': '#F18F01',
|
||||
'subtask': '#C73E1D',
|
||||
'fast': '#6A994E'
|
||||
}
|
||||
|
||||
# Draw segment boundaries and labels
|
||||
for seg_type, start, end in positions:
|
||||
color = segment_colors.get(seg_type, '#666666')
|
||||
|
||||
# Draw vertical lines for columns (keys)
|
||||
ax.axvline(x=start - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||
ax.axvline(x=end - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||
|
||||
# Draw horizontal lines for rows (queries)
|
||||
ax.axhline(y=start - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||
ax.axhline(y=end - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||
|
||||
# Add labels at the top
|
||||
mid_pos = (start + end) / 2
|
||||
ax.text(mid_pos, -mask.shape[0] * 0.02, f"{seg_type.upper()}\n({end - start})",
|
||||
ha='center', va='top', fontsize=10, fontweight='bold', color=color)
|
||||
|
||||
# Add labels on the left
|
||||
ax.text(-mask.shape[1] * 0.02, mid_pos, f"{seg_type.upper()}\n({end - start})",
|
||||
ha='right', va='center', fontsize=10, fontweight='bold', color=color, rotation=0)
|
||||
|
||||
# Set axis labels
|
||||
ax.set_xlabel('Key Position (tokens being attended to)', fontsize=12, fontweight='bold')
|
||||
ax.set_ylabel('Query Position (tokens attending)', fontsize=12, fontweight='bold')
|
||||
ax.set_title('Attention Mask Pattern\n(White = No Attention, Blue = Attention Allowed)',
|
||||
fontsize=14, fontweight='bold', pad=20)
|
||||
|
||||
# Create legend for segment types
|
||||
legend_patches = []
|
||||
attention_rules = {
|
||||
'image': 'Bidirectional with lang',
|
||||
'language': 'Bidirectional with images',
|
||||
'subtask': 'Attends to img+lang, causal self',
|
||||
'fast': 'Attends to all, causal self'
|
||||
}
|
||||
for seg_type, color in segment_colors.items():
|
||||
if any(seg[0] == seg_type for seg in att_mask_segments):
|
||||
rule = attention_rules.get(seg_type, '')
|
||||
legend_patches.append(mpatches.Patch(color=color, label=f'{seg_type.upper()}: {rule}'))
|
||||
|
||||
ax.legend(handles=legend_patches, loc='upper right', bbox_to_anchor=(1.15, 1.0),
|
||||
framealpha=0.9, fontsize=9)
|
||||
|
||||
# Adjust layout and save
|
||||
plt.tight_layout()
|
||||
|
||||
# Ensure the directory exists
|
||||
save_path = Path(save_path)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
logging.info(f"Attention mask visualization saved to: {save_path}")
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
mean=0.0,
|
||||
@@ -609,8 +809,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, tokens, subtask_tokens, masks, subtask_masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
subtask_tokens,
|
||||
masks,
|
||||
subtask_masks,
|
||||
fast_action_tokens=None,
|
||||
fast_action_masks=None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
||||
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
|
||||
|
||||
Args:
|
||||
@@ -619,17 +827,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
tokens: Language instruction tokens
|
||||
subtask_tokens: Subtask tokens to predict (can be None for inference)
|
||||
masks: Attention masks for tokens
|
||||
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
|
||||
fast_action_masks: Padding masks for FAST action tokens (can be None)
|
||||
|
||||
Returns:
|
||||
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided)]
|
||||
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
|
||||
pad_masks: Padding masks
|
||||
att_masks: Attention masks (with causal masking for subtask prediction if subtask_tokens provided)
|
||||
att_masks: Custom 2D attention mask implementing the required pattern
|
||||
total_T_images: Total number of image tokens
|
||||
num_subtask_embs: Number of subtask token embeddings
|
||||
num_fast_embs: Number of FAST action token embeddings
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
att_mask_segments = [] # Store info about each segment for custom mask creation
|
||||
total_T_images = 0
|
||||
num_subtask_embs = 0
|
||||
num_fast_embs = 0
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
@@ -642,7 +856,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
|
||||
att_mask_segments.append(('image', num_img_embs))
|
||||
total_T_images += num_img_embs
|
||||
|
||||
# Process language instruction tokens
|
||||
@@ -656,7 +870,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
pad_masks.append(masks)
|
||||
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs # Language tokens can attend to all previous tokens (images + tokens)
|
||||
att_mask_segments.append(('language', num_lang_embs))
|
||||
|
||||
# Process subtask tokens if provided (these are predicted, so use causal masking)
|
||||
if subtask_tokens is not None:
|
||||
@@ -672,18 +886,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
pad_masks.append(subtask_masks)
|
||||
|
||||
num_subtask_embs = subtask_emb.shape[1]
|
||||
# Causal masking for subtask tokens: each subtask token can attend to images, all instruction tokens,
|
||||
# and previous subtask tokens
|
||||
att_masks += [1] * num_subtask_embs # Use 1 for causal attention on subtask tokens
|
||||
att_mask_segments.append(('subtask', num_subtask_embs))
|
||||
# Process FAST action tokens if provided (these are discrete token IDs)
|
||||
if fast_action_tokens is not None:
|
||||
def fast_action_embed_func(fast_action_tokens):
|
||||
fast_emb = self.fast_action_embedding(fast_action_tokens)
|
||||
fast_emb_dim = fast_emb.shape[-1]
|
||||
return fast_emb * math.sqrt(fast_emb_dim)
|
||||
|
||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||
embs.append(fast_action_emb)
|
||||
|
||||
# Use provided mask or create default (all valid)
|
||||
if fast_action_masks is not None:
|
||||
fast_pad_mask = fast_action_masks
|
||||
else:
|
||||
bsize = fast_action_tokens.shape[0]
|
||||
num_fast_embs = fast_action_tokens.shape[1]
|
||||
fast_pad_mask = torch.ones(bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device)
|
||||
|
||||
num_fast_embs = fast_action_tokens.shape[1]
|
||||
pad_masks.append(fast_pad_mask)
|
||||
att_mask_segments.append(('fast', num_fast_embs))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
|
||||
bsize = pad_masks.shape[0]
|
||||
att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0])
|
||||
# Create custom 2D attention mask
|
||||
# Attention rules:
|
||||
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||
# - FAST: attend to images + language + subtask, causal among themselves
|
||||
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
|
||||
|
||||
return embs, pad_masks, att_masks, total_T_images
|
||||
# # Optionally visualize the attention mask
|
||||
# self.visualize_attention_mask(
|
||||
# att_mask_segments=att_mask_segments,
|
||||
# att_2d_masks=att_masks,
|
||||
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
|
||||
# batch_idx=0,
|
||||
# max_display_tokens=512 # Limit display for very long sequences
|
||||
# )
|
||||
|
||||
return embs, pad_masks, att_masks, total_T_images, num_subtask_embs, num_fast_embs
|
||||
|
||||
def embed_suffix(self, noisy_actions, timestep):
|
||||
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
@@ -732,8 +977,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions)
|
||||
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, noise=None, time=None) -> Tensor:
|
||||
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions, fast_action_tokens, fast_action_masks)
|
||||
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, fast_action_tokens=None, fast_action_masks=None, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss.
|
||||
|
||||
Args:
|
||||
@@ -743,7 +988,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
high_level_task_masks: Attention masks for high_level_task
|
||||
subtask_tokens: Subtask tokens to predict (e.g., tokens for "pick up the cup")
|
||||
subtask_masks: Attention masks for subtask_tokens
|
||||
actions: Ground truth actions
|
||||
actions: Ground truth actions [B, chunk_size, action_dim]
|
||||
fast_action_tokens: Discrete action token IDs [B, max_action_tokens]
|
||||
fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens]
|
||||
noise: Optional noise for flow matching
|
||||
time: Optional time for flow matching
|
||||
"""
|
||||
@@ -757,39 +1004,45 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
# Embed prefix (images + high_level_task + subtask_tokens)
|
||||
# Use high_level_task (prompt WITHOUT subtask) + subtask_tokens to predict
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
||||
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks
|
||||
# Initialize FAST loss to 0 (will be computed only if FAST tokens are provided)
|
||||
fast_loss = torch.tensor(0.0, device=actions.device, dtype=actions.dtype)
|
||||
|
||||
# ========== PASS 1: Prefix with FAST tokens for subtask + FAST prediction ==========
|
||||
# Only run this pass if FAST action tokens are provided
|
||||
if fast_action_tokens is not None and fast_action_masks is not None:
|
||||
# Embed prefix (images + high_level_task + subtask_tokens + FAST tokens)
|
||||
# FAST tokens are provided as discrete token IDs
|
||||
prefix_with_fast_embs, prefix_with_fast_pad_masks, prefix_with_fast_att_masks, total_T_images, num_subtask_embs, num_fast_embs = self.embed_prefix(
|
||||
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
|
||||
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
|
||||
)
|
||||
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
# Convert embeddings to bfloat16 if needed for the model
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_with_fast_embs = prefix_with_fast_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
# Prepare attention masks for prefix-only pass (for subtask token prediction)
|
||||
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
||||
# Prepare attention masks for prefix pass with FAST tokens
|
||||
position_ids_prefix_with_fast = torch.cumsum(prefix_with_fast_pad_masks, dim=1) - 1
|
||||
att_2d_prefix_with_fast_4d = self._prepare_attention_masks_4d(prefix_with_fast_att_masks, dtype=prefix_with_fast_embs.dtype)
|
||||
|
||||
# prefix-only transformer run for subtask token prediction
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_4d,
|
||||
position_ids=position_ids_prefix,
|
||||
# Forward pass through paligemma for subtask + FAST prediction
|
||||
(prefix_with_fast_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_with_fast_4d,
|
||||
position_ids=position_ids_prefix_with_fast,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None], # SUFFIX = None
|
||||
inputs_embeds=[prefix_with_fast_embs, None], # SUFFIX = None
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# LM HEAD → SUBTASK LOGITS
|
||||
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_high_level_task + T_subtask
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
||||
logits = lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, vocab)
|
||||
|
||||
# Extract logits for subtask token prediction
|
||||
# In autoregressive modeling, output at position i predicts token at position i+1
|
||||
# So we take logits from one position earlier:
|
||||
# - Position (start_index-1) (last high_level_task token) predicts subtask_tokens[0]
|
||||
# - Position (start_index) (first subtask token) predicts subtask_tokens[1], etc.
|
||||
T_high_level_task = high_level_task.size(1)
|
||||
T_subtask = subtask_tokens.size(1)
|
||||
start_index = total_T_images + T_high_level_task
|
||||
@@ -797,35 +1050,137 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab)
|
||||
|
||||
targets = subtask_tokens # (B, T_subtask)
|
||||
# Compute cross-entropy loss
|
||||
# Compute cross-entropy loss for subtask
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
# Reshape for loss computation
|
||||
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) # (B*T_subtask, vocab)
|
||||
targets_flat = targets.reshape(-1) # (B*T_subtask)
|
||||
|
||||
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_subtask)
|
||||
loss_per_token = loss_per_token.reshape(targets.shape) # (B, T_subtask)
|
||||
|
||||
# Apply mask and compute mean loss over valid tokens
|
||||
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
|
||||
targets_flat = targets.reshape(-1)
|
||||
loss_per_token = loss_fct(logits_flat, targets_flat)
|
||||
loss_per_token = loss_per_token.reshape(targets.shape)
|
||||
masked_loss = loss_per_token * subtask_masks.float()
|
||||
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
|
||||
|
||||
# Extract outputs for FAST action token prediction and compute auxiliary loss
|
||||
# FAST outputs start after subtask tokens
|
||||
# Similar to subtask, we use autoregressive prediction where position i predicts token i+1
|
||||
fast_start_index = end_index
|
||||
fast_end_index = fast_start_index + num_fast_embs
|
||||
|
||||
# Get logits for FAST action tokens using the FAST LM head
|
||||
fast_logits = self.fast_action_lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, fast_vocab_size)
|
||||
|
||||
# Extract logits for FAST token prediction (autoregressive: position i predicts token i+1)
|
||||
# - Position (fast_start_index-1) predicts fast_action_tokens[0]
|
||||
# - Position (fast_start_index) predicts fast_action_tokens[1], etc.
|
||||
fast_logits_for_pred = fast_logits[:, fast_start_index-1:fast_end_index-1, :] # (B, max_action_tokens, fast_vocab_size)
|
||||
|
||||
# Compute cross-entropy loss for FAST action tokens
|
||||
fast_targets = fast_action_tokens # (B, max_action_tokens)
|
||||
loss_fct_fast = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) # (B*max_action_tokens, fast_vocab_size)
|
||||
fast_targets_flat = fast_targets.reshape(-1) # (B*max_action_tokens)
|
||||
|
||||
fast_loss_per_token = loss_fct_fast(fast_logits_flat, fast_targets_flat) # (B*max_action_tokens)
|
||||
fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape) # (B, max_action_tokens)
|
||||
|
||||
# Apply mask and compute mean loss over valid tokens
|
||||
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
|
||||
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
|
||||
else:
|
||||
# If no FAST tokens provided, compute subtask loss without FAST tokens
|
||||
# This is the fallback for backward compatibility
|
||||
prefix_embs_for_subtask, prefix_pad_masks_for_subtask, prefix_att_masks_for_subtask, total_T_images, _, _ = self.embed_prefix(
|
||||
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
|
||||
fast_action_tokens=None, fast_action_masks=None
|
||||
)
|
||||
|
||||
# Convert embeddings to bfloat16 if needed for the model
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs_for_subtask = prefix_embs_for_subtask.to(dtype=torch.bfloat16)
|
||||
|
||||
position_ids_prefix = torch.cumsum(prefix_pad_masks_for_subtask, dim=1) - 1
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks_for_subtask, dtype=prefix_embs_for_subtask.dtype)
|
||||
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_4d,
|
||||
position_ids=position_ids_prefix,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs_for_subtask, None],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
logits = lm_head(prefix_out)
|
||||
|
||||
T_high_level_task = high_level_task.size(1)
|
||||
T_subtask = subtask_tokens.size(1)
|
||||
start_index = total_T_images + T_high_level_task
|
||||
end_index = start_index + T_subtask
|
||||
logits_subtask = logits[:, start_index-1:end_index-1, :]
|
||||
|
||||
targets = subtask_tokens
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
|
||||
targets_flat = targets.reshape(-1)
|
||||
loss_per_token = loss_fct(logits_flat, targets_flat)
|
||||
loss_per_token = loss_per_token.reshape(targets.shape)
|
||||
masked_loss = loss_per_token * subtask_masks.float()
|
||||
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
|
||||
|
||||
# ========== PASS 2: Full forward WITHOUT FAST tokens for flow matching ==========
|
||||
# Embed prefix WITHOUT FAST tokens (images + high_level_task + subtask_tokens)
|
||||
prefix_embs_no_fast, prefix_pad_masks_no_fast, prefix_att_masks_no_fast, _, _, _ = self.embed_prefix(
|
||||
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
|
||||
fast_action_tokens=None, fast_action_masks=None
|
||||
)
|
||||
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
# Convert embeddings to bfloat16 if needed for the model
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs_no_fast = prefix_embs_no_fast.to(dtype=torch.bfloat16)
|
||||
|
||||
# Concatenate prefix (images + tokens + subtask_tokens) and suffix (actions) masks
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
# For the flow matching pass, we need custom attention where:
|
||||
# - prefix follows the custom pattern (images+lang bidirectional, subtask causal, no cross-attention)
|
||||
# - suffix attends to all prefix + causal to itself
|
||||
# We'll construct this by extending prefix_att_masks_no_fast to include suffix
|
||||
|
||||
# prefix_att_masks_no_fast is already a 2D boolean mask [B, prefix_len, prefix_len]
|
||||
# We need to extend it to [B, prefix_len + suffix_len, prefix_len + suffix_len]
|
||||
|
||||
bsize = prefix_pad_masks_no_fast.shape[0]
|
||||
prefix_len = prefix_pad_masks_no_fast.shape[1]
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
total_len = prefix_len + suffix_len
|
||||
device = prefix_pad_masks_no_fast.device
|
||||
|
||||
# Create full attention mask
|
||||
full_att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
||||
|
||||
# Copy prefix attention pattern
|
||||
full_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_masks_no_fast
|
||||
|
||||
# Suffix attends to all prefix
|
||||
full_att_2d_masks[:, prefix_len:, :prefix_len] = True
|
||||
|
||||
# Suffix has causal attention among itself
|
||||
suffix_causal_mask = torch.tril(torch.ones(suffix_len, suffix_len, dtype=torch.bool, device=device))
|
||||
full_att_2d_masks[:, prefix_len:, prefix_len:] = suffix_causal_mask[None, :, :]
|
||||
|
||||
# Apply padding masks
|
||||
pad_masks = torch.cat([prefix_pad_masks_no_fast, suffix_pad_masks], dim=1)
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
full_att_2d_masks = full_att_2d_masks & pad_2d_masks
|
||||
|
||||
# Prepare attention masks for full forward pass (prefix + suffix)
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=prefix_embs_no_fast.dtype)
|
||||
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
@@ -836,11 +1191,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
# prefix_out to be used for the language head
|
||||
return suffix_out
|
||||
|
||||
suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
forward_func, prefix_embs_no_fast, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
@@ -856,78 +1210,80 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return {
|
||||
"flow_loss": fm_loss,
|
||||
"subtask_loss": subtask_loss,
|
||||
"loss": 10 * fm_loss.mean() + subtask_loss,
|
||||
"fast_loss": fast_loss,
|
||||
"loss": fm_loss.mean() + 0.1 * subtask_loss + 0.05 * fast_loss, # ref: b1k winner
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _generate_subtask_tokens(
|
||||
self, images, img_masks, tokens, masks, tokenizer, max_length, device
|
||||
):
|
||||
"""Generate subtask tokens autoregressively using next token prediction."""
|
||||
bsize = tokens.shape[0]
|
||||
|
||||
# Get lm_head for token generation
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
# Embed prefix without subtask tokens first
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
||||
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _, _ = self.embed_prefix(
|
||||
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None,
|
||||
fast_action_tokens=None, fast_action_masks=None
|
||||
)
|
||||
|
||||
# Initialize generated tokens list - start with BOS token or first token after instruction
|
||||
# For PaliGemma, we'll start generation and accumulate tokens
|
||||
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
|
||||
|
||||
for t in range(max_length):
|
||||
# Prepare attention masks for current prefix
|
||||
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
||||
# tracking mask: False = still generating, True = finished
|
||||
finished = torch.zeros(bsize, dtype=torch.bool, device=device)
|
||||
|
||||
for t in range(max_length):
|
||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
# Forward pass through model to get logits
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_4d,
|
||||
position_ids=position_ids_prefix,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
# ...
|
||||
)
|
||||
|
||||
# Get logits from the last position
|
||||
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
||||
next_token_logits = logits[:, -1, :] # (B, vocab)
|
||||
|
||||
# Greedy decoding - take the most likely token
|
||||
logits = lm_head(prefix_out)
|
||||
next_token_logits = logits[:, -1, :]
|
||||
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
|
||||
|
||||
# Store generated token
|
||||
# 1. if a row was already finished, force the next token to be PAD (0)
|
||||
next_token = torch.where(finished, torch.tensor(0, device=device), next_token)
|
||||
|
||||
# 2. store the token
|
||||
generated_tokens[:, t] = next_token
|
||||
|
||||
# Check for EOS token - if all batches have generated EOS, stop
|
||||
# 3. update the finished mask
|
||||
if tokenizer.eos_token_id is not None:
|
||||
if (next_token == tokenizer.eos_token_id).all():
|
||||
finished |= (next_token == tokenizer.eos_token_id)
|
||||
|
||||
# 4. break only if everyone is finished
|
||||
if finished.all():
|
||||
break
|
||||
|
||||
# Embed the generated token and append to prefix
|
||||
next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1)
|
||||
next_token_unsqueezed = next_token.unsqueeze(1)
|
||||
|
||||
def next_token_embed_func(next_token_unsqueezed):
|
||||
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
|
||||
next_emb_dim = next_emb.shape[-1]
|
||||
return next_emb * math.sqrt(next_emb_dim)
|
||||
return next_emb * math.sqrt(next_emb.shape[-1])
|
||||
|
||||
next_emb = self._apply_checkpoint(next_token_embed_func, next_token_unsqueezed)
|
||||
|
||||
# Append to prefix embeddings
|
||||
# update embeddings
|
||||
prefix_embs = torch.cat([prefix_embs, next_emb], dim=1)
|
||||
|
||||
# Update masks - new token is valid and uses causal attention
|
||||
# update padding masks
|
||||
prefix_pad_masks = torch.cat([
|
||||
prefix_pad_masks,
|
||||
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
], dim=1)
|
||||
prefix_att_masks = torch.cat([prefix_att_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
|
||||
|
||||
# update attention masks
|
||||
old_seq_len = prefix_att_masks.shape[1]
|
||||
new_seq_len = old_seq_len + 1
|
||||
new_att_masks = torch.zeros((bsize, new_seq_len, new_seq_len), dtype=torch.bool, device=device)
|
||||
new_att_masks[:, :old_seq_len, :old_seq_len] = prefix_att_masks
|
||||
new_att_masks[:, -1, :] = prefix_pad_masks
|
||||
prefix_att_masks = new_att_masks
|
||||
|
||||
return generated_tokens
|
||||
|
||||
@@ -978,13 +1334,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
subtask_masks = torch.ones_like(generated_subtask_tokens, dtype=torch.bool)
|
||||
|
||||
# During inference, we don't have subtask_tokens yet, so pass None
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
|
||||
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks
|
||||
# Also no FAST tokens during inference
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, _, _, _ = self.embed_prefix(
|
||||
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks,
|
||||
fast_action_tokens=None, fast_action_masks=None
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
|
||||
# Convert embeddings to bfloat16 if needed for the model
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
# prefix_att_masks is already a 2D attention mask from embed_prefix
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype)
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
@@ -1209,7 +1575,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
print(f"Remapped {remap_count} state dict keys")
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
|
||||
@@ -1442,19 +1808,26 @@ class PI05Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Decode and print ground truth subtask tokens during training
|
||||
if self.tokenizer is not None and self.training:
|
||||
bsize = subtask_tokens.shape[0]
|
||||
for i in range(bsize):
|
||||
# Remove padding tokens (0) and special tokens
|
||||
valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
|
||||
if len(valid_tokens) > 0:
|
||||
decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
||||
# print(f"[Training] Ground truth subtask {i}: {decoded_text}")
|
||||
# if self.tokenizer is not None and self.training:
|
||||
# bsize = subtask_tokens.shape[0]
|
||||
# for i in range(bsize):
|
||||
# # Remove padding tokens (0) and special tokens
|
||||
# valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
|
||||
# # if len(valid_tokens) > 0:
|
||||
# # decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
||||
# # print(f"[Training] Ground truth subtask {i}: {decoded_text}")
|
||||
|
||||
# Get FAST action tokens from batch
|
||||
fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens)
|
||||
fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens)
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
# high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
|
||||
# subtask_tokens = subtask tokens to predict (e.g., "pick up the cup")
|
||||
loss_dict = self.model.forward(images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions)
|
||||
# fast_action_tokens = discrete action token IDs to predict
|
||||
loss_dict = self.model.forward(
|
||||
images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions,
|
||||
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
|
||||
)
|
||||
|
||||
# Extract the total loss
|
||||
loss = loss_dict["loss"]
|
||||
@@ -1464,6 +1837,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
"loss": loss.item(),
|
||||
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
||||
"subtask_loss": loss_dict["subtask_loss"].item(),
|
||||
"fast_loss": loss_dict["fast_loss"].item(),
|
||||
}
|
||||
|
||||
return loss, detailed_loss_dict
|
||||
|
||||
@@ -33,6 +33,7 @@ from lerobot.processor import (
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
ActionTokenizerProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
@@ -158,7 +159,6 @@ def make_pi05_pre_post_processors(
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
@@ -177,6 +177,9 @@ def make_pi05_pre_post_processors(
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
ActionTokenizerProcessorStep(
|
||||
tokenizer_name="physical-intelligence/fast",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ from .policy_robot_bridge import (
|
||||
RobotActionToPolicyActionProcessorStep,
|
||||
)
|
||||
from .rename_processor import RenameObservationsProcessorStep
|
||||
from .tokenizer_processor import TokenizerProcessorStep
|
||||
from .tokenizer_processor import TokenizerProcessorStep, ActionTokenizerProcessorStep
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessorStep",
|
||||
|
||||
@@ -15,10 +15,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script defines a processor for tokenizing natural language instructions from an environment transition.
|
||||
This script defines processors for tokenizing data from an environment transition.
|
||||
|
||||
It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
|
||||
token IDs and attention masks, which are then added to the observation dictionary.
|
||||
It includes:
|
||||
- TokenizerProcessorStep: Uses a tokenizer from the Hugging Face `transformers` library to convert
|
||||
task descriptions (text) into token IDs and attention masks, which are then added to the observation dictionary.
|
||||
- ActionTokenizerProcessorStep: Uses a processor/tokenizer (e.g., the Physical Intelligence "fast" tokenizer)
|
||||
to tokenize action tensors into discrete token IDs for action modeling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -30,6 +33,8 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
||||
@@ -40,12 +45,13 @@ from lerobot.utils.constants import (
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
else:
|
||||
AutoProcessor = None
|
||||
AutoTokenizer = None
|
||||
|
||||
|
||||
@@ -423,3 +429,233 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="action_tokenizer_processor")
|
||||
class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
Processor step to tokenize action data using a fast action tokenizer.
|
||||
|
||||
This step takes action tensors from an `EnvTransition`, tokenizes them using
|
||||
a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
|
||||
and returns the tokenized action.
|
||||
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
||||
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||
"""
|
||||
|
||||
tokenizer_name: str | None = None
|
||||
tokenizer: Any | None = None
|
||||
trust_remote_code: bool = True
|
||||
max_action_tokens: int = 32
|
||||
# Internal tokenizer instance (not part of the config)
|
||||
action_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Initializes the action tokenizer after the dataclass is created.
|
||||
|
||||
It checks for the availability of the `transformers` library and loads the tokenizer
|
||||
either from a provided object or by name from the Hugging Face Hub.
|
||||
|
||||
Raises:
|
||||
ImportError: If the `transformers` library is not installed.
|
||||
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
|
||||
"""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
# Use provided tokenizer object directly
|
||||
self.action_tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
if AutoProcessor is None:
|
||||
raise ImportError("AutoProcessor is not available")
|
||||
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||
self.tokenizer_name, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
|
||||
"Pass a tokenizer object directly or a tokenizer name to auto-load."
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Applies action tokenization to the transition.
|
||||
|
||||
This overrides the base class to handle both tokens and mask.
|
||||
|
||||
Args:
|
||||
transition: The input transition with action data.
|
||||
|
||||
Returns:
|
||||
The processed transition with tokenized actions and mask in complementary data.
|
||||
"""
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
raise ValueError("ActionTokenizerProcessorStep requires an action in the transition.")
|
||||
|
||||
# Tokenize and get both tokens and mask
|
||||
tokens, mask = self._tokenize_action(action)
|
||||
|
||||
# Store mask in complementary data
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if complementary_data is None:
|
||||
complementary_data = {}
|
||||
complementary_data[ACTION_TOKEN_MASK] = mask
|
||||
complementary_data[ACTION_TOKENS] = tokens
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
return new_transition
|
||||
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenizes the action tensor and creates a mask.
|
||||
|
||||
Args:
|
||||
action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
|
||||
|
||||
Returns:
|
||||
A tuple of (tokens, mask) where:
|
||||
- tokens: Tensor of token IDs with shape (B, max_action_tokens)
|
||||
- mask: Boolean mask with shape (B, max_action_tokens), True for real tokens, False for padding
|
||||
"""
|
||||
if action is None:
|
||||
raise ValueError("Action cannot be None")
|
||||
|
||||
# Get the device and dtype of the input action
|
||||
device = action.device if isinstance(action, torch.Tensor) else None
|
||||
|
||||
# Handle single sample (add batch dimension)
|
||||
single_sample = action.dim() == 1
|
||||
if single_sample:
|
||||
action = action.unsqueeze(0)
|
||||
|
||||
batch_size = action.shape[0]
|
||||
|
||||
# Tokenize the action batch
|
||||
# The fast tokenizer expects action data and returns token IDs
|
||||
tokens_list = []
|
||||
masks_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
||||
action_cpu = action[i:i+1].cpu()
|
||||
tokens = self.action_tokenizer(action_cpu)
|
||||
|
||||
# Convert to numpy array if it's a list
|
||||
if isinstance(tokens, list):
|
||||
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
|
||||
elif not isinstance(tokens, torch.Tensor):
|
||||
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
|
||||
else:
|
||||
# Move tokens back to the same device as input action
|
||||
tokens = tokens.to(device=action.device)
|
||||
|
||||
# Flatten to 1D if needed
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
# Truncate or pad to max_action_tokens
|
||||
if len(tokens) > self.max_action_tokens:
|
||||
tokens = tokens[:self.max_action_tokens]
|
||||
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
||||
else:
|
||||
mask = torch.cat([
|
||||
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
||||
torch.zeros(self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device)
|
||||
])
|
||||
# Pad tokens with zeros
|
||||
tokens = torch.nn.functional.pad(
|
||||
tokens,
|
||||
(0, self.max_action_tokens - len(tokens)),
|
||||
value=0
|
||||
)
|
||||
|
||||
tokens_list.append(tokens)
|
||||
masks_list.append(mask)
|
||||
|
||||
# Stack into batched tensors
|
||||
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
||||
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
tokens_batch = tokens_batch.squeeze(0)
|
||||
masks_batch = masks_batch.squeeze(0)
|
||||
|
||||
# Move to the same device as the input
|
||||
if device is not None:
|
||||
tokens_batch = tokens_batch.to(device)
|
||||
masks_batch = masks_batch.to(device)
|
||||
|
||||
return tokens_batch, masks_batch
|
||||
|
||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method is not used since we override __call__.
|
||||
Required by ActionProcessorStep ABC.
|
||||
"""
|
||||
tokens, _ = self._tokenize_action(action)
|
||||
return tokens
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the serializable configuration of the processor.
|
||||
|
||||
Note: The tokenizer object itself is not serialized. If the processor was initialized
|
||||
with a tokenizer name, that name will be included in the config.
|
||||
|
||||
Returns:
|
||||
A dictionary with the processor's configuration parameters.
|
||||
"""
|
||||
config = {
|
||||
"trust_remote_code": self.trust_remote_code,
|
||||
"max_action_tokens": self.max_action_tokens,
|
||||
}
|
||||
|
||||
# Only save tokenizer_name if it was used to create the tokenizer
|
||||
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||
config["tokenizer_name"] = self.tokenizer_name
|
||||
|
||||
return config
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Updates feature definitions to reflect tokenized actions.
|
||||
|
||||
This updates the policy features dictionary to indicate that the action
|
||||
has been tokenized into a sequence of token IDs with shape (max_action_tokens,).
|
||||
|
||||
Args:
|
||||
features: The dictionary of existing policy features.
|
||||
|
||||
Returns:
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
# Update the action feature to reflect the tokenized shape
|
||||
# The action is now a sequence of token IDs
|
||||
if PipelineFeatureType.ACTION in features:
|
||||
# Replace the action feature with the tokenized version
|
||||
features[PipelineFeatureType.ACTION] = {
|
||||
ACTION_TOKENS: PolicyFeature(
|
||||
type=FeatureType.SEQUENCE, # Token sequence
|
||||
shape=(self.max_action_tokens,)
|
||||
)
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
@@ -90,6 +90,7 @@ def update_policy(
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
action = policy.select_action(batch)
|
||||
breakpoint()
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
|
||||
@@ -33,6 +33,8 @@ OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask"
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens"
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask"
|
||||
ACTION = "action"
|
||||
ACTION_TOKENS = ACTION + ".tokens"
|
||||
ACTION_TOKEN_MASK = ACTION + ".token_mask"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
|
||||
Reference in New Issue
Block a user