mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
add fast tokenizer support
This commit is contained in:
@@ -0,0 +1,138 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
"""
|
||||||
|
Example demonstrating how to use the ActionTokenizerProcessorStep to tokenize actions.
|
||||||
|
|
||||||
|
This example shows how to:
|
||||||
|
1. Load a dataset with action data
|
||||||
|
2. Apply the action tokenizer processor to tokenize actions with proper padding/truncation
|
||||||
|
3. Access both the tokenized actions and the attention mask
|
||||||
|
4. Decode tokenized actions back to their original form
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||||
|
from lerobot.processor.tokenizer_processor import ActionTokenizerProcessorStep
|
||||||
|
from lerobot.utils.constants import ACTION_TOKEN_MASK
|
||||||
|
|
||||||
|
# Define delta timestamps for the dataset
|
||||||
|
delta_timestamps = {
|
||||||
|
'action': [
|
||||||
|
0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333,
|
||||||
|
0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3,
|
||||||
|
0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335,
|
||||||
|
0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6,
|
||||||
|
0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333,
|
||||||
|
0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9,
|
||||||
|
0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334,
|
||||||
|
1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2,
|
||||||
|
1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333,
|
||||||
|
1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5,
|
||||||
|
1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load the dataset
|
||||||
|
print("Loading dataset...")
|
||||||
|
dataset = LeRobotDataset(
|
||||||
|
repo_id="local",
|
||||||
|
root="/fsx/jade_choghari/outputs/pgen_annotations1",
|
||||||
|
delta_timestamps=delta_timestamps
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a dataloader
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=4,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get a batch of data
|
||||||
|
batch = next(iter(dataloader))
|
||||||
|
action_data = batch["action"] # Shape: (batch_size, action_horizon, action_dim)
|
||||||
|
|
||||||
|
print(f"\nOriginal action shape: {action_data.shape}")
|
||||||
|
print(f"Original action data (first sample, first timestep):\n{action_data[0, 0]}")
|
||||||
|
|
||||||
|
# Method 1: Using the tokenizer directly (as in fast_tokenize.py)
|
||||||
|
print("\n" + "="*80)
|
||||||
|
print("Method 1: Direct tokenizer usage")
|
||||||
|
print("="*80)
|
||||||
|
|
||||||
|
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||||
|
|
||||||
|
# Tokenize directly
|
||||||
|
tokens = tokenizer(action_data)
|
||||||
|
print(f"\nDirect tokenization result type: {type(tokens)}")
|
||||||
|
print(f"Tokens shape/length: {tokens.shape if isinstance(tokens, torch.Tensor) else len(tokens)}")
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
decoded_actions = tokenizer.decode(tokens)
|
||||||
|
print(f"Decoded actions shape: {decoded_actions.shape}")
|
||||||
|
reconstruction_error = torch.abs(action_data - decoded_actions).mean()
|
||||||
|
print(f"Mean absolute reconstruction error: {reconstruction_error.item():.6f}")
|
||||||
|
|
||||||
|
# Method 2: Using the ActionTokenizerProcessorStep with proper padding/truncation
|
||||||
|
print("\n" + "="*80)
|
||||||
|
print("Method 2: Using ActionTokenizerProcessorStep (with padding & mask)")
|
||||||
|
print("="*80)
|
||||||
|
|
||||||
|
# Create the action tokenizer processor step
|
||||||
|
action_tokenizer_processor = ActionTokenizerProcessorStep(
|
||||||
|
tokenizer_name="physical-intelligence/fast",
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_action_tokens=32, # Maximum number of tokens per action
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a transition with the action data
|
||||||
|
transition = {
|
||||||
|
TransitionKey.ACTION: action_data,
|
||||||
|
TransitionKey.OBSERVATION: {}, # Empty for this example
|
||||||
|
}
|
||||||
|
|
||||||
|
# Apply the processor
|
||||||
|
processed_transition = action_tokenizer_processor(transition)
|
||||||
|
|
||||||
|
# Extract tokenized actions and mask
|
||||||
|
tokenized_actions = processed_transition[TransitionKey.ACTION]
|
||||||
|
complementary_data = processed_transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
action_mask = complementary_data[ACTION_TOKEN_MASK]
|
||||||
|
|
||||||
|
print(f"\nTokenized actions shape: {tokenized_actions.shape}") # (batch_size, max_action_tokens)
|
||||||
|
print(f"Action mask shape: {action_mask.shape}") # (batch_size, max_action_tokens)
|
||||||
|
print(f"Tokenized actions dtype: {tokenized_actions.dtype}")
|
||||||
|
print(f"Action mask dtype: {action_mask.dtype}")
|
||||||
|
|
||||||
|
# Show token statistics
|
||||||
|
print(f"\nFirst sample tokens: {tokenized_actions[0]}")
|
||||||
|
print(f"First sample mask: {action_mask[0]}")
|
||||||
|
num_real_tokens = action_mask[0].sum().item()
|
||||||
|
print(f"Number of real tokens (non-padding): {num_real_tokens}")
|
||||||
|
print(f"Number of padding tokens: {action_mask.shape[1] - num_real_tokens}")
|
||||||
|
|
||||||
|
# Decode using the mask
|
||||||
|
print("\nDecoding tokenized actions...")
|
||||||
|
decoded_with_processor = tokenizer.decode(tokenized_actions)
|
||||||
|
print(f"Decoded actions shape: {decoded_with_processor.shape}")
|
||||||
|
|
||||||
|
# Calculate reconstruction error
|
||||||
|
reconstruction_error_processor = torch.abs(action_data - decoded_with_processor).mean()
|
||||||
|
print(f"Mean absolute reconstruction error: {reconstruction_error_processor.item():.6f}")
|
||||||
|
|
||||||
|
# Show that masking works correctly
|
||||||
|
print("\n" + "="*80)
|
||||||
|
print("Mask demonstration")
|
||||||
|
print("="*80)
|
||||||
|
for i in range(min(4, tokenized_actions.shape[0])):
|
||||||
|
mask_i = action_mask[i]
|
||||||
|
num_real = mask_i.sum().item()
|
||||||
|
print(f"Sample {i}: {num_real} real tokens, {len(mask_i) - num_real} padding tokens")
|
||||||
|
|
||||||
|
print("\n" + "="*80)
|
||||||
|
print("Action tokenization example completed successfully!")
|
||||||
|
print("="*80)
|
||||||
|
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
import numpy as np
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
import torch
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
|
|
||||||
|
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||||
|
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
|
||||||
|
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=4,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch = next(iter(dataloader))
|
||||||
|
|
||||||
|
# Load the tokenizer from the Hugging Face hub
|
||||||
|
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||||
|
|
||||||
|
# Tokenize & decode action chunks (we use dummy data here)
|
||||||
|
action_data = batch["action"] # one batch of action chunks
|
||||||
|
tokens = tokenizer(action_data) # tokens = list[int]
|
||||||
|
decoded_actions = tokenizer.decode(tokens)
|
||||||
|
print("tokenized actions: ", tokens)
|
||||||
@@ -10,13 +10,13 @@ from lerobot.policies.factory import make_policy, make_policy_config
|
|||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
|
||||||
cfg = PreTrainedConfig.from_pretrained(
|
cfg = PreTrainedConfig.from_pretrained(
|
||||||
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
|
||||||
)
|
)
|
||||||
cfg.dtype = "bfloat16"
|
cfg.dtype = "bfloat16"
|
||||||
|
|
||||||
pre_processor, post_processor = make_pre_post_processors(
|
pre_processor, post_processor = make_pre_post_processors(
|
||||||
policy_cfg=cfg,
|
policy_cfg=cfg,
|
||||||
pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
pretrained_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
|
||||||
)
|
)
|
||||||
|
|
||||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||||
@@ -45,13 +45,12 @@ dataloader = torch.utils.data.DataLoader(
|
|||||||
)
|
)
|
||||||
|
|
||||||
batch = next(iter(dataloader))
|
batch = next(iter(dataloader))
|
||||||
|
|
||||||
batch = pre_processor(batch)
|
batch = pre_processor(batch)
|
||||||
breakpoint()
|
|
||||||
policy.train()
|
policy.train()
|
||||||
# run inference
|
# run inference
|
||||||
# action = policy.select_action(batch)
|
# action = policy.select_action(batch)
|
||||||
loss, loss_dict = policy.forward(batch)
|
loss, loss_dict = policy.forward(batch)
|
||||||
|
breakpoint()
|
||||||
# import requests
|
# import requests
|
||||||
# from PIL import Image
|
# from PIL import Image
|
||||||
# from transformers import AutoProcessor
|
# from transformers import AutoProcessor
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
## One-sentence answer
|
||||||
|
|
||||||
|
> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**.
|
||||||
|
|
||||||
|
Everything else you’ve seen so far was just metadata.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## What goes in
|
||||||
|
|
||||||
|
### Inputs
|
||||||
|
|
||||||
|
```python
|
||||||
|
prefix_pad_masks # shape [B, L]
|
||||||
|
prefix_att_masks # shape [B, L]
|
||||||
|
```
|
||||||
|
|
||||||
|
Where:
|
||||||
|
|
||||||
|
* `prefix_pad_masks[b, i] = True`
|
||||||
|
→ token `i` exists (not padding)
|
||||||
|
|
||||||
|
* `prefix_att_masks[b, i] = False`
|
||||||
|
→ token `i` is **bidirectional**
|
||||||
|
|
||||||
|
* `prefix_att_masks[b, i] = True`
|
||||||
|
→ token `i` is **causal (autoregressive)**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## What comes out
|
||||||
|
|
||||||
|
```python
|
||||||
|
att_2d_prefix # shape [B, L, L]
|
||||||
|
```
|
||||||
|
|
||||||
|
Each entry:
|
||||||
|
|
||||||
|
```text
|
||||||
|
att_2d_prefix[b, i, j] = True
|
||||||
|
```
|
||||||
|
|
||||||
|
means:
|
||||||
|
|
||||||
|
> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.”
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## How it is constructed (conceptually)
|
||||||
|
|
||||||
|
For **each batch b**, **each query position i**, **each key position j**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
if not prefix_pad_masks[b, j]:
|
||||||
|
att[b, i, j] = False # cannot attend to padding
|
||||||
|
else if not prefix_att_masks[b, i]:
|
||||||
|
att[b, i, j] = True # bidirectional token → can see all real tokens
|
||||||
|
else:
|
||||||
|
att[b, i, j] = (j <= i) # causal token → can see only past + itself
|
||||||
|
```
|
||||||
|
|
||||||
|
That’s it.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Tiny concrete example (exactly matching your code)
|
||||||
|
|
||||||
|
Suppose:
|
||||||
|
|
||||||
|
```python
|
||||||
|
prefix_pad_masks[0] = [T, T, T, T, T, F]
|
||||||
|
prefix_att_masks[0] = [F, F, F, T, T, T]
|
||||||
|
```
|
||||||
|
|
||||||
|
Tokens:
|
||||||
|
|
||||||
|
```
|
||||||
|
0: IMG
|
||||||
|
1: IMG
|
||||||
|
2: LANG
|
||||||
|
3: SUB0
|
||||||
|
4: SUB1
|
||||||
|
5: PAD
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Resulting `att_2d_prefix[0]`
|
||||||
|
|
||||||
|
`✓ = True, ✗ = False`
|
||||||
|
|
||||||
|
| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 |
|
||||||
|
| ---------- | - | - | - | - | - | - |
|
||||||
|
| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||||
|
| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||||
|
| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||||
|
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
|
||||||
|
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||||
|
| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Why this matters for your training code
|
||||||
|
|
||||||
|
This line:
|
||||||
|
|
||||||
|
```python
|
||||||
|
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
|
||||||
|
```
|
||||||
|
|
||||||
|
Converts `[B, L, L] → [B, 1, L, L]` and possibly flips True/False to `0/-inf`.
|
||||||
|
|
||||||
|
This is **exactly what Paligemma uses inside self-attention**.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key implications (VERY important)
|
||||||
|
|
||||||
|
### 1️⃣ This mask does **not isolate token groups**
|
||||||
|
|
||||||
|
* Bidirectional tokens can attend to **everything**
|
||||||
|
* Causal tokens only restrict *their own row*
|
||||||
|
|
||||||
|
So **flow/action tokens must be blocked separately**.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2️⃣ This is why your AR subtask prediction works
|
||||||
|
|
||||||
|
* Subtask tokens are causal
|
||||||
|
* Output at position `i` predicts token `i+1`
|
||||||
|
* Padding is fully ignored
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3️⃣ Inference behavior
|
||||||
|
|
||||||
|
When `subtask_tokens = None`:
|
||||||
|
|
||||||
|
* `prefix_att_masks` contains only `False`
|
||||||
|
* `att_2d_prefix` becomes **fully bidirectional**
|
||||||
|
* No AR behavior remains
|
||||||
|
|
||||||
|
Exactly what you want.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## One-sentence takeaway (commit this)
|
||||||
|
|
||||||
|
> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.
|
||||||
|
|
||||||
|
If you want next, I can:
|
||||||
|
|
||||||
|
* inspect `make_att_2d_masks()` source with you
|
||||||
|
* show how to block **flow → subtask** attention
|
||||||
|
* explain how this changes when suffix tokens are added
|
||||||
|
* help you refactor this into a cleaner “grouped attention” API
|
||||||
|
|
||||||
|
You’re now at the point where the model’s behavior should feel *predictable*, not magical.
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
srun --time 12:00:00 --qos=high --gres=gpu:1 --mem=24G --partition=hopper-prod --container-image /fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh --container-mounts /fsx/jade_choghari
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
|
|
||||||
Fine tune output
|
|
||||||
(Pdb) images[2].mean()
|
|
||||||
tensor(-1., device='cuda:0')
|
|
||||||
(Pdb) images[1].mean()
|
|
||||||
tensor(-0.5780, device='cuda:0')
|
|
||||||
(Pdb) images[0].mean()
|
|
||||||
tensor(-0.7716, device='cuda:0')
|
|
||||||
(Pdb) (Pdb) high_level_task[0]
|
|
||||||
tensor([ 2, 7978, 2403, 6911, 235292, 5651, 3124, 573, 18571,
|
|
||||||
7762, 6643, 573, 9010, 72993, 21810, 4894, 3040, 235292,
|
|
||||||
235248, 235274, 235274, 235274, 728, 235274, 235248, 235284, 235308,
|
|
||||||
235308, 235248, 235274, 235318, 235315, 235248, 235274, 235310, 235318,
|
|
||||||
235248, 235284, 235318, 235248, 235274, 235284, 235321, 235248, 235274,
|
|
||||||
235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284,
|
|
||||||
235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321,
|
|
||||||
235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248,
|
|
||||||
235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274,
|
|
||||||
235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284,
|
|
||||||
235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321,
|
|
||||||
235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248,
|
|
||||||
235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274,
|
|
||||||
235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284,
|
|
||||||
235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321,
|
|
||||||
235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235289,
|
|
||||||
4284, 8277, 235292, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0], device='cuda:0')
|
|
||||||
(Pdb) subtask_tokens[0]
|
|
||||||
tensor([ 2, 28040, 7762, 14574, 6643, 9010, 37901, 21810, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
|
||||||
device='cuda:0')
|
|
||||||
(Pdb) actions.shape
|
|
||||||
torch.Size([4, 50, 32])
|
|
||||||
(Pdb) actions.mean()
|
|
||||||
tensor(0.0143, device='cuda:0')
|
|
||||||
(Pdb)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Inference:
|
|
||||||
@@ -37,6 +37,9 @@ class PI05Config(PreTrainedConfig):
|
|||||||
# Shorter state and action vectors will be padded to these dimensions
|
# Shorter state and action vectors will be padded to these dimensions
|
||||||
max_state_dim: int = 32
|
max_state_dim: int = 32
|
||||||
max_action_dim: int = 32
|
max_action_dim: int = 32
|
||||||
|
max_action_tokens: int = 32
|
||||||
|
fast_vocab_size: int = 2048
|
||||||
|
|
||||||
|
|
||||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||||
num_inference_steps: int = 10
|
num_inference_steps: int = 10
|
||||||
|
|||||||
@@ -537,6 +537,18 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||||
|
|
||||||
|
# FAST action token embedding and prediction head
|
||||||
|
self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
|
||||||
|
self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
|
||||||
|
|
||||||
|
# Apply dtype conversion to FAST layers to match model precision
|
||||||
|
if config.dtype == "bfloat16":
|
||||||
|
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
|
||||||
|
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
|
||||||
|
elif config.dtype == "float32":
|
||||||
|
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
|
||||||
|
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
|
||||||
|
|
||||||
# Initialize gradient checkpointing flag
|
# Initialize gradient checkpointing flag
|
||||||
self.gradient_checkpointing_enabled = False
|
self.gradient_checkpointing_enabled = False
|
||||||
|
|
||||||
@@ -592,6 +604,194 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
result = result.to(dtype=dtype)
|
result = result.to(dtype=dtype)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
|
||||||
|
"""Create custom 2D attention mask for the new attention pattern.
|
||||||
|
|
||||||
|
Attention rules:
|
||||||
|
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||||
|
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||||
|
- FAST: attend to images + language + subtask, causal among themselves
|
||||||
|
|
||||||
|
Args:
|
||||||
|
att_mask_segments: List of (type, length) tuples
|
||||||
|
pad_masks: Padding masks [B, total_seq_len]
|
||||||
|
bsize: Batch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
|
||||||
|
"""
|
||||||
|
total_len = sum(length for _, length in att_mask_segments)
|
||||||
|
device = pad_masks.device
|
||||||
|
|
||||||
|
# Initialize attention mask as False (cannot attend)
|
||||||
|
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
# Track positions for each segment
|
||||||
|
positions = []
|
||||||
|
current_pos = 0
|
||||||
|
for seg_type, seg_len in att_mask_segments:
|
||||||
|
positions.append((seg_type, current_pos, current_pos + seg_len))
|
||||||
|
current_pos += seg_len
|
||||||
|
|
||||||
|
# Apply attention rules
|
||||||
|
for i, (query_type, query_start, query_end) in enumerate(positions):
|
||||||
|
for j, (key_type, key_start, key_end) in enumerate(positions):
|
||||||
|
# Images and Language can attend to each other bidirectionally
|
||||||
|
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||||
|
|
||||||
|
# Subtask tokens attend to images + language
|
||||||
|
elif query_type == 'subtask' and key_type in ['image', 'language']:
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||||
|
|
||||||
|
# Subtask tokens attend causally to themselves
|
||||||
|
elif query_type == 'subtask' and key_type == 'subtask':
|
||||||
|
# Create causal mask for subtask tokens
|
||||||
|
subtask_len = query_end - query_start
|
||||||
|
causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device))
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||||
|
|
||||||
|
# FAST tokens attend to images + language + subtask
|
||||||
|
elif query_type == 'fast' and key_type in ['image', 'language', 'subtask']:
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||||
|
|
||||||
|
# FAST tokens attend causally to themselves
|
||||||
|
elif query_type == 'fast' and key_type == 'fast':
|
||||||
|
fast_len = query_end - query_start
|
||||||
|
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
|
||||||
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||||
|
|
||||||
|
# Apply padding masks
|
||||||
|
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||||
|
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||||
|
|
||||||
|
return att_2d_masks
|
||||||
|
|
||||||
|
def visualize_attention_mask(
|
||||||
|
self,
|
||||||
|
att_mask_segments,
|
||||||
|
att_2d_masks,
|
||||||
|
save_path,
|
||||||
|
batch_idx=0,
|
||||||
|
dpi=150,
|
||||||
|
max_display_tokens=None
|
||||||
|
):
|
||||||
|
"""Visualize the attention mask with labeled segments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
att_mask_segments: List of (type, length) tuples defining the segments
|
||||||
|
att_2d_masks: 2D attention mask tensor [B, total_seq_len, total_seq_len]
|
||||||
|
save_path: Path where to save the visualization image
|
||||||
|
batch_idx: Which batch item to visualize (default: 0)
|
||||||
|
dpi: DPI for the saved image (default: 150)
|
||||||
|
max_display_tokens: Maximum number of tokens to display (for very long sequences)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.patches as mpatches
|
||||||
|
from matplotlib.colors import LinearSegmentedColormap
|
||||||
|
except ImportError:
|
||||||
|
logging.warning("matplotlib not available, skipping attention mask visualization")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract the mask for the specified batch
|
||||||
|
mask = att_2d_masks[batch_idx].cpu().float().numpy()
|
||||||
|
|
||||||
|
# If sequence is too long, downsample for visualization
|
||||||
|
if max_display_tokens is not None and mask.shape[0] > max_display_tokens:
|
||||||
|
# Simple downsampling by taking every Nth token
|
||||||
|
step = mask.shape[0] // max_display_tokens
|
||||||
|
mask = mask[::step, ::step]
|
||||||
|
# Adjust segments accordingly
|
||||||
|
att_mask_segments = [(seg_type, max(1, seg_len // step)) for seg_type, seg_len in att_mask_segments]
|
||||||
|
|
||||||
|
# Calculate positions for each segment
|
||||||
|
positions = []
|
||||||
|
current_pos = 0
|
||||||
|
for seg_type, seg_len in att_mask_segments:
|
||||||
|
positions.append((seg_type, current_pos, current_pos + seg_len))
|
||||||
|
current_pos += seg_len
|
||||||
|
|
||||||
|
# Create figure
|
||||||
|
fig, ax = plt.subplots(figsize=(12, 10))
|
||||||
|
|
||||||
|
# Create custom colormap: white for False (no attention), blue for True (attention)
|
||||||
|
colors = ['white', '#2E86AB']
|
||||||
|
n_bins = 2
|
||||||
|
cmap = LinearSegmentedColormap.from_list('attention', colors, N=n_bins)
|
||||||
|
|
||||||
|
# Display the mask
|
||||||
|
im = ax.imshow(mask, cmap=cmap, aspect='auto', interpolation='nearest', vmin=0, vmax=1)
|
||||||
|
|
||||||
|
# Add colorbar
|
||||||
|
cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
||||||
|
cbar.set_label('Attention Enabled', rotation=270, labelpad=20)
|
||||||
|
cbar.set_ticks([0.25, 0.75])
|
||||||
|
cbar.set_ticklabels(['No', 'Yes'])
|
||||||
|
|
||||||
|
# Define colors for each segment type
|
||||||
|
segment_colors = {
|
||||||
|
'image': '#A23B72',
|
||||||
|
'language': '#F18F01',
|
||||||
|
'subtask': '#C73E1D',
|
||||||
|
'fast': '#6A994E'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Draw segment boundaries and labels
|
||||||
|
for seg_type, start, end in positions:
|
||||||
|
color = segment_colors.get(seg_type, '#666666')
|
||||||
|
|
||||||
|
# Draw vertical lines for columns (keys)
|
||||||
|
ax.axvline(x=start - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||||
|
ax.axvline(x=end - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||||
|
|
||||||
|
# Draw horizontal lines for rows (queries)
|
||||||
|
ax.axhline(y=start - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||||
|
ax.axhline(y=end - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||||
|
|
||||||
|
# Add labels at the top
|
||||||
|
mid_pos = (start + end) / 2
|
||||||
|
ax.text(mid_pos, -mask.shape[0] * 0.02, f"{seg_type.upper()}\n({end - start})",
|
||||||
|
ha='center', va='top', fontsize=10, fontweight='bold', color=color)
|
||||||
|
|
||||||
|
# Add labels on the left
|
||||||
|
ax.text(-mask.shape[1] * 0.02, mid_pos, f"{seg_type.upper()}\n({end - start})",
|
||||||
|
ha='right', va='center', fontsize=10, fontweight='bold', color=color, rotation=0)
|
||||||
|
|
||||||
|
# Set axis labels
|
||||||
|
ax.set_xlabel('Key Position (tokens being attended to)', fontsize=12, fontweight='bold')
|
||||||
|
ax.set_ylabel('Query Position (tokens attending)', fontsize=12, fontweight='bold')
|
||||||
|
ax.set_title('Attention Mask Pattern\n(White = No Attention, Blue = Attention Allowed)',
|
||||||
|
fontsize=14, fontweight='bold', pad=20)
|
||||||
|
|
||||||
|
# Create legend for segment types
|
||||||
|
legend_patches = []
|
||||||
|
attention_rules = {
|
||||||
|
'image': 'Bidirectional with lang',
|
||||||
|
'language': 'Bidirectional with images',
|
||||||
|
'subtask': 'Attends to img+lang, causal self',
|
||||||
|
'fast': 'Attends to all, causal self'
|
||||||
|
}
|
||||||
|
for seg_type, color in segment_colors.items():
|
||||||
|
if any(seg[0] == seg_type for seg in att_mask_segments):
|
||||||
|
rule = attention_rules.get(seg_type, '')
|
||||||
|
legend_patches.append(mpatches.Patch(color=color, label=f'{seg_type.upper()}: {rule}'))
|
||||||
|
|
||||||
|
ax.legend(handles=legend_patches, loc='upper right', bbox_to_anchor=(1.15, 1.0),
|
||||||
|
framealpha=0.9, fontsize=9)
|
||||||
|
|
||||||
|
# Adjust layout and save
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Ensure the directory exists
|
||||||
|
save_path = Path(save_path)
|
||||||
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
logging.info(f"Attention mask visualization saved to: {save_path}")
|
||||||
|
|
||||||
def sample_noise(self, shape, device):
|
def sample_noise(self, shape, device):
|
||||||
return torch.normal(
|
return torch.normal(
|
||||||
mean=0.0,
|
mean=0.0,
|
||||||
@@ -609,8 +809,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
return time.to(dtype=torch.float32, device=device)
|
return time.to(dtype=torch.float32, device=device)
|
||||||
|
|
||||||
def embed_prefix(
|
def embed_prefix(
|
||||||
self, images, img_masks, tokens, subtask_tokens, masks, subtask_masks
|
self,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
images,
|
||||||
|
img_masks,
|
||||||
|
tokens,
|
||||||
|
subtask_tokens,
|
||||||
|
masks,
|
||||||
|
subtask_masks,
|
||||||
|
fast_action_tokens=None,
|
||||||
|
fast_action_masks=None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
||||||
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
|
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -619,17 +827,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
tokens: Language instruction tokens
|
tokens: Language instruction tokens
|
||||||
subtask_tokens: Subtask tokens to predict (can be None for inference)
|
subtask_tokens: Subtask tokens to predict (can be None for inference)
|
||||||
masks: Attention masks for tokens
|
masks: Attention masks for tokens
|
||||||
|
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
|
||||||
|
fast_action_masks: Padding masks for FAST action tokens (can be None)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided)]
|
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
|
||||||
pad_masks: Padding masks
|
pad_masks: Padding masks
|
||||||
att_masks: Attention masks (with causal masking for subtask prediction if subtask_tokens provided)
|
att_masks: Custom 2D attention mask implementing the required pattern
|
||||||
total_T_images: Total number of image tokens
|
total_T_images: Total number of image tokens
|
||||||
|
num_subtask_embs: Number of subtask token embeddings
|
||||||
|
num_fast_embs: Number of FAST action token embeddings
|
||||||
"""
|
"""
|
||||||
embs = []
|
embs = []
|
||||||
pad_masks = []
|
pad_masks = []
|
||||||
att_masks = []
|
att_mask_segments = [] # Store info about each segment for custom mask creation
|
||||||
total_T_images = 0
|
total_T_images = 0
|
||||||
|
num_subtask_embs = 0
|
||||||
|
num_fast_embs = 0
|
||||||
|
|
||||||
# Process images
|
# Process images
|
||||||
for img, img_mask in zip(images, img_masks, strict=True):
|
for img, img_mask in zip(images, img_masks, strict=True):
|
||||||
@@ -642,7 +856,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
embs.append(img_emb)
|
embs.append(img_emb)
|
||||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||||
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
|
att_mask_segments.append(('image', num_img_embs))
|
||||||
total_T_images += num_img_embs
|
total_T_images += num_img_embs
|
||||||
|
|
||||||
# Process language instruction tokens
|
# Process language instruction tokens
|
||||||
@@ -656,7 +870,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
pad_masks.append(masks)
|
pad_masks.append(masks)
|
||||||
|
|
||||||
num_lang_embs = lang_emb.shape[1]
|
num_lang_embs = lang_emb.shape[1]
|
||||||
att_masks += [0] * num_lang_embs # Language tokens can attend to all previous tokens (images + tokens)
|
att_mask_segments.append(('language', num_lang_embs))
|
||||||
|
|
||||||
# Process subtask tokens if provided (these are predicted, so use causal masking)
|
# Process subtask tokens if provided (these are predicted, so use causal masking)
|
||||||
if subtask_tokens is not None:
|
if subtask_tokens is not None:
|
||||||
@@ -672,18 +886,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
pad_masks.append(subtask_masks)
|
pad_masks.append(subtask_masks)
|
||||||
|
|
||||||
num_subtask_embs = subtask_emb.shape[1]
|
num_subtask_embs = subtask_emb.shape[1]
|
||||||
# Causal masking for subtask tokens: each subtask token can attend to images, all instruction tokens,
|
att_mask_segments.append(('subtask', num_subtask_embs))
|
||||||
# and previous subtask tokens
|
# Process FAST action tokens if provided (these are discrete token IDs)
|
||||||
att_masks += [1] * num_subtask_embs # Use 1 for causal attention on subtask tokens
|
if fast_action_tokens is not None:
|
||||||
|
def fast_action_embed_func(fast_action_tokens):
|
||||||
|
fast_emb = self.fast_action_embedding(fast_action_tokens)
|
||||||
|
fast_emb_dim = fast_emb.shape[-1]
|
||||||
|
return fast_emb * math.sqrt(fast_emb_dim)
|
||||||
|
|
||||||
|
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||||
|
embs.append(fast_action_emb)
|
||||||
|
|
||||||
|
# Use provided mask or create default (all valid)
|
||||||
|
if fast_action_masks is not None:
|
||||||
|
fast_pad_mask = fast_action_masks
|
||||||
|
else:
|
||||||
|
bsize = fast_action_tokens.shape[0]
|
||||||
|
num_fast_embs = fast_action_tokens.shape[1]
|
||||||
|
fast_pad_mask = torch.ones(bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device)
|
||||||
|
|
||||||
|
num_fast_embs = fast_action_tokens.shape[1]
|
||||||
|
pad_masks.append(fast_pad_mask)
|
||||||
|
att_mask_segments.append(('fast', num_fast_embs))
|
||||||
|
|
||||||
embs = torch.cat(embs, dim=1)
|
embs = torch.cat(embs, dim=1)
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
pad_masks = torch.cat(pad_masks, dim=1)
|
||||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
|
||||||
|
|
||||||
bsize = pad_masks.shape[0]
|
# Create custom 2D attention mask
|
||||||
att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0])
|
# Attention rules:
|
||||||
|
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||||
|
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||||
|
# - FAST: attend to images + language + subtask, causal among themselves
|
||||||
|
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
|
||||||
|
|
||||||
return embs, pad_masks, att_masks, total_T_images
|
# # Optionally visualize the attention mask
|
||||||
|
# self.visualize_attention_mask(
|
||||||
|
# att_mask_segments=att_mask_segments,
|
||||||
|
# att_2d_masks=att_masks,
|
||||||
|
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
|
||||||
|
# batch_idx=0,
|
||||||
|
# max_display_tokens=512 # Limit display for very long sequences
|
||||||
|
# )
|
||||||
|
|
||||||
|
return embs, pad_masks, att_masks, total_T_images, num_subtask_embs, num_fast_embs
|
||||||
|
|
||||||
def embed_suffix(self, noisy_actions, timestep):
|
def embed_suffix(self, noisy_actions, timestep):
|
||||||
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||||
@@ -732,8 +977,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions)
|
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions, fast_action_tokens, fast_action_masks)
|
||||||
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, noise=None, time=None) -> Tensor:
|
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, fast_action_tokens=None, fast_action_masks=None, noise=None, time=None) -> Tensor:
|
||||||
"""Do a full training forward pass and compute the loss.
|
"""Do a full training forward pass and compute the loss.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -743,7 +988,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
high_level_task_masks: Attention masks for high_level_task
|
high_level_task_masks: Attention masks for high_level_task
|
||||||
subtask_tokens: Subtask tokens to predict (e.g., tokens for "pick up the cup")
|
subtask_tokens: Subtask tokens to predict (e.g., tokens for "pick up the cup")
|
||||||
subtask_masks: Attention masks for subtask_tokens
|
subtask_masks: Attention masks for subtask_tokens
|
||||||
actions: Ground truth actions
|
actions: Ground truth actions [B, chunk_size, action_dim]
|
||||||
|
fast_action_tokens: Discrete action token IDs [B, max_action_tokens]
|
||||||
|
fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens]
|
||||||
noise: Optional noise for flow matching
|
noise: Optional noise for flow matching
|
||||||
time: Optional time for flow matching
|
time: Optional time for flow matching
|
||||||
"""
|
"""
|
||||||
@@ -757,39 +1004,45 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
|
|
||||||
# Embed prefix (images + high_level_task + subtask_tokens)
|
# Initialize FAST loss to 0 (will be computed only if FAST tokens are provided)
|
||||||
# Use high_level_task (prompt WITHOUT subtask) + subtask_tokens to predict
|
fast_loss = torch.tensor(0.0, device=actions.device, dtype=actions.dtype)
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
|
||||||
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks
|
# ========== PASS 1: Prefix with FAST tokens for subtask + FAST prediction ==========
|
||||||
|
# Only run this pass if FAST action tokens are provided
|
||||||
|
if fast_action_tokens is not None and fast_action_masks is not None:
|
||||||
|
# Embed prefix (images + high_level_task + subtask_tokens + FAST tokens)
|
||||||
|
# FAST tokens are provided as discrete token IDs
|
||||||
|
prefix_with_fast_embs, prefix_with_fast_pad_masks, prefix_with_fast_att_masks, total_T_images, num_subtask_embs, num_fast_embs = self.embed_prefix(
|
||||||
|
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
|
||||||
|
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
|
||||||
)
|
)
|
||||||
|
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
# Convert embeddings to bfloat16 if needed for the model
|
||||||
|
if (
|
||||||
|
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
|
== torch.bfloat16
|
||||||
|
):
|
||||||
|
prefix_with_fast_embs = prefix_with_fast_embs.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
# Prepare attention masks for prefix-only pass (for subtask token prediction)
|
# Prepare attention masks for prefix pass with FAST tokens
|
||||||
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
position_ids_prefix_with_fast = torch.cumsum(prefix_with_fast_pad_masks, dim=1) - 1
|
||||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
att_2d_prefix_with_fast_4d = self._prepare_attention_masks_4d(prefix_with_fast_att_masks, dtype=prefix_with_fast_embs.dtype)
|
||||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
|
||||||
|
|
||||||
# prefix-only transformer run for subtask token prediction
|
# Forward pass through paligemma for subtask + FAST prediction
|
||||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
(prefix_with_fast_out, _), _ = self.paligemma_with_expert.forward(
|
||||||
attention_mask=att_2d_prefix_4d,
|
attention_mask=att_2d_prefix_with_fast_4d,
|
||||||
position_ids=position_ids_prefix,
|
position_ids=position_ids_prefix_with_fast,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=[prefix_embs, None], # SUFFIX = None
|
inputs_embeds=[prefix_with_fast_embs, None], # SUFFIX = None
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
adarms_cond=[None, None],
|
adarms_cond=[None, None],
|
||||||
)
|
)
|
||||||
|
|
||||||
# LM HEAD → SUBTASK LOGITS
|
# LM HEAD → SUBTASK LOGITS
|
||||||
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_high_level_task + T_subtask
|
|
||||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||||
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
logits = lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, vocab)
|
||||||
|
|
||||||
# Extract logits for subtask token prediction
|
# Extract logits for subtask token prediction
|
||||||
# In autoregressive modeling, output at position i predicts token at position i+1
|
|
||||||
# So we take logits from one position earlier:
|
|
||||||
# - Position (start_index-1) (last high_level_task token) predicts subtask_tokens[0]
|
|
||||||
# - Position (start_index) (first subtask token) predicts subtask_tokens[1], etc.
|
|
||||||
T_high_level_task = high_level_task.size(1)
|
T_high_level_task = high_level_task.size(1)
|
||||||
T_subtask = subtask_tokens.size(1)
|
T_subtask = subtask_tokens.size(1)
|
||||||
start_index = total_T_images + T_high_level_task
|
start_index = total_T_images + T_high_level_task
|
||||||
@@ -797,35 +1050,137 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab)
|
logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab)
|
||||||
|
|
||||||
targets = subtask_tokens # (B, T_subtask)
|
targets = subtask_tokens # (B, T_subtask)
|
||||||
# Compute cross-entropy loss
|
# Compute cross-entropy loss for subtask
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||||
# Reshape for loss computation
|
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
|
||||||
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) # (B*T_subtask, vocab)
|
targets_flat = targets.reshape(-1)
|
||||||
targets_flat = targets.reshape(-1) # (B*T_subtask)
|
loss_per_token = loss_fct(logits_flat, targets_flat)
|
||||||
|
loss_per_token = loss_per_token.reshape(targets.shape)
|
||||||
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_subtask)
|
|
||||||
loss_per_token = loss_per_token.reshape(targets.shape) # (B, T_subtask)
|
|
||||||
|
|
||||||
# Apply mask and compute mean loss over valid tokens
|
|
||||||
masked_loss = loss_per_token * subtask_masks.float()
|
masked_loss = loss_per_token * subtask_masks.float()
|
||||||
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
|
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
|
||||||
|
|
||||||
|
# Extract outputs for FAST action token prediction and compute auxiliary loss
|
||||||
|
# FAST outputs start after subtask tokens
|
||||||
|
# Similar to subtask, we use autoregressive prediction where position i predicts token i+1
|
||||||
|
fast_start_index = end_index
|
||||||
|
fast_end_index = fast_start_index + num_fast_embs
|
||||||
|
|
||||||
|
# Get logits for FAST action tokens using the FAST LM head
|
||||||
|
fast_logits = self.fast_action_lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, fast_vocab_size)
|
||||||
|
|
||||||
|
# Extract logits for FAST token prediction (autoregressive: position i predicts token i+1)
|
||||||
|
# - Position (fast_start_index-1) predicts fast_action_tokens[0]
|
||||||
|
# - Position (fast_start_index) predicts fast_action_tokens[1], etc.
|
||||||
|
fast_logits_for_pred = fast_logits[:, fast_start_index-1:fast_end_index-1, :] # (B, max_action_tokens, fast_vocab_size)
|
||||||
|
|
||||||
|
# Compute cross-entropy loss for FAST action tokens
|
||||||
|
fast_targets = fast_action_tokens # (B, max_action_tokens)
|
||||||
|
loss_fct_fast = torch.nn.CrossEntropyLoss(reduction='none')
|
||||||
|
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) # (B*max_action_tokens, fast_vocab_size)
|
||||||
|
fast_targets_flat = fast_targets.reshape(-1) # (B*max_action_tokens)
|
||||||
|
|
||||||
|
fast_loss_per_token = loss_fct_fast(fast_logits_flat, fast_targets_flat) # (B*max_action_tokens)
|
||||||
|
fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape) # (B, max_action_tokens)
|
||||||
|
|
||||||
|
# Apply mask and compute mean loss over valid tokens
|
||||||
|
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
|
||||||
|
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
|
||||||
|
else:
|
||||||
|
# If no FAST tokens provided, compute subtask loss without FAST tokens
|
||||||
|
# This is the fallback for backward compatibility
|
||||||
|
prefix_embs_for_subtask, prefix_pad_masks_for_subtask, prefix_att_masks_for_subtask, total_T_images, _, _ = self.embed_prefix(
|
||||||
|
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
|
||||||
|
fast_action_tokens=None, fast_action_masks=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert embeddings to bfloat16 if needed for the model
|
||||||
|
if (
|
||||||
|
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
|
== torch.bfloat16
|
||||||
|
):
|
||||||
|
prefix_embs_for_subtask = prefix_embs_for_subtask.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
position_ids_prefix = torch.cumsum(prefix_pad_masks_for_subtask, dim=1) - 1
|
||||||
|
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks_for_subtask, dtype=prefix_embs_for_subtask.dtype)
|
||||||
|
|
||||||
|
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||||
|
attention_mask=att_2d_prefix_4d,
|
||||||
|
position_ids=position_ids_prefix,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=[prefix_embs_for_subtask, None],
|
||||||
|
use_cache=False,
|
||||||
|
adarms_cond=[None, None],
|
||||||
|
)
|
||||||
|
|
||||||
|
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||||
|
logits = lm_head(prefix_out)
|
||||||
|
|
||||||
|
T_high_level_task = high_level_task.size(1)
|
||||||
|
T_subtask = subtask_tokens.size(1)
|
||||||
|
start_index = total_T_images + T_high_level_task
|
||||||
|
end_index = start_index + T_subtask
|
||||||
|
logits_subtask = logits[:, start_index-1:end_index-1, :]
|
||||||
|
|
||||||
|
targets = subtask_tokens
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||||
|
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
|
||||||
|
targets_flat = targets.reshape(-1)
|
||||||
|
loss_per_token = loss_fct(logits_flat, targets_flat)
|
||||||
|
loss_per_token = loss_per_token.reshape(targets.shape)
|
||||||
|
masked_loss = loss_per_token * subtask_masks.float()
|
||||||
|
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
|
||||||
|
|
||||||
|
# ========== PASS 2: Full forward WITHOUT FAST tokens for flow matching ==========
|
||||||
|
# Embed prefix WITHOUT FAST tokens (images + high_level_task + subtask_tokens)
|
||||||
|
prefix_embs_no_fast, prefix_pad_masks_no_fast, prefix_att_masks_no_fast, _, _, _ = self.embed_prefix(
|
||||||
|
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
|
||||||
|
fast_action_tokens=None, fast_action_masks=None
|
||||||
|
)
|
||||||
|
|
||||||
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||||
|
|
||||||
# Convert embeddings to bfloat16 if needed for the model
|
# Convert embeddings to bfloat16 if needed for the model
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
):
|
):
|
||||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
prefix_embs_no_fast = prefix_embs_no_fast.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
# Concatenate prefix (images + tokens + subtask_tokens) and suffix (actions) masks
|
# For the flow matching pass, we need custom attention where:
|
||||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
# - prefix follows the custom pattern (images+lang bidirectional, subtask causal, no cross-attention)
|
||||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
# - suffix attends to all prefix + causal to itself
|
||||||
|
# We'll construct this by extending prefix_att_masks_no_fast to include suffix
|
||||||
|
|
||||||
|
# prefix_att_masks_no_fast is already a 2D boolean mask [B, prefix_len, prefix_len]
|
||||||
|
# We need to extend it to [B, prefix_len + suffix_len, prefix_len + suffix_len]
|
||||||
|
|
||||||
|
bsize = prefix_pad_masks_no_fast.shape[0]
|
||||||
|
prefix_len = prefix_pad_masks_no_fast.shape[1]
|
||||||
|
suffix_len = suffix_pad_masks.shape[1]
|
||||||
|
total_len = prefix_len + suffix_len
|
||||||
|
device = prefix_pad_masks_no_fast.device
|
||||||
|
|
||||||
|
# Create full attention mask
|
||||||
|
full_att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
# Copy prefix attention pattern
|
||||||
|
full_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_masks_no_fast
|
||||||
|
|
||||||
|
# Suffix attends to all prefix
|
||||||
|
full_att_2d_masks[:, prefix_len:, :prefix_len] = True
|
||||||
|
|
||||||
|
# Suffix has causal attention among itself
|
||||||
|
suffix_causal_mask = torch.tril(torch.ones(suffix_len, suffix_len, dtype=torch.bool, device=device))
|
||||||
|
full_att_2d_masks[:, prefix_len:, prefix_len:] = suffix_causal_mask[None, :, :]
|
||||||
|
|
||||||
|
# Apply padding masks
|
||||||
|
pad_masks = torch.cat([prefix_pad_masks_no_fast, suffix_pad_masks], dim=1)
|
||||||
|
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||||
|
full_att_2d_masks = full_att_2d_masks & pad_2d_masks
|
||||||
|
|
||||||
# Prepare attention masks for full forward pass (prefix + suffix)
|
|
||||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
|
||||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=prefix_embs_no_fast.dtype)
|
||||||
|
|
||||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||||
@@ -836,11 +1191,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
use_cache=False,
|
use_cache=False,
|
||||||
adarms_cond=[None, adarms_cond],
|
adarms_cond=[None, adarms_cond],
|
||||||
)
|
)
|
||||||
# prefix_out to be used for the language head
|
|
||||||
return suffix_out
|
return suffix_out
|
||||||
|
|
||||||
suffix_out = self._apply_checkpoint(
|
suffix_out = self._apply_checkpoint(
|
||||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
forward_func, prefix_embs_no_fast, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||||
)
|
)
|
||||||
|
|
||||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||||
@@ -856,78 +1210,80 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
return {
|
return {
|
||||||
"flow_loss": fm_loss,
|
"flow_loss": fm_loss,
|
||||||
"subtask_loss": subtask_loss,
|
"subtask_loss": subtask_loss,
|
||||||
"loss": 10 * fm_loss.mean() + subtask_loss,
|
"fast_loss": fast_loss,
|
||||||
|
"loss": fm_loss.mean() + 0.1 * subtask_loss + 0.05 * fast_loss, # ref: b1k winner
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _generate_subtask_tokens(
|
def _generate_subtask_tokens(
|
||||||
self, images, img_masks, tokens, masks, tokenizer, max_length, device
|
self, images, img_masks, tokens, masks, tokenizer, max_length, device
|
||||||
):
|
):
|
||||||
"""Generate subtask tokens autoregressively using next token prediction."""
|
|
||||||
bsize = tokens.shape[0]
|
bsize = tokens.shape[0]
|
||||||
|
|
||||||
# Get lm_head for token generation
|
|
||||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||||
|
|
||||||
# Embed prefix without subtask tokens first
|
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _, _ = self.embed_prefix(
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None,
|
||||||
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None
|
fast_action_tokens=None, fast_action_masks=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize generated tokens list - start with BOS token or first token after instruction
|
|
||||||
# For PaliGemma, we'll start generation and accumulate tokens
|
|
||||||
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
|
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
|
||||||
|
|
||||||
for t in range(max_length):
|
# tracking mask: False = still generating, True = finished
|
||||||
# Prepare attention masks for current prefix
|
finished = torch.zeros(bsize, dtype=torch.bool, device=device)
|
||||||
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
|
||||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
for t in range(max_length):
|
||||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||||
|
|
||||||
# Forward pass through model to get logits
|
|
||||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||||
attention_mask=att_2d_prefix_4d,
|
attention_mask=att_2d_prefix_4d,
|
||||||
position_ids=position_ids_prefix,
|
position_ids=position_ids_prefix,
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds=[prefix_embs, None],
|
inputs_embeds=[prefix_embs, None],
|
||||||
use_cache=False,
|
# ...
|
||||||
adarms_cond=[None, None],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get logits from the last position
|
logits = lm_head(prefix_out)
|
||||||
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
next_token_logits = logits[:, -1, :]
|
||||||
next_token_logits = logits[:, -1, :] # (B, vocab)
|
|
||||||
|
|
||||||
# Greedy decoding - take the most likely token
|
|
||||||
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
|
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
|
||||||
|
|
||||||
# Store generated token
|
# 1. if a row was already finished, force the next token to be PAD (0)
|
||||||
|
next_token = torch.where(finished, torch.tensor(0, device=device), next_token)
|
||||||
|
|
||||||
|
# 2. store the token
|
||||||
generated_tokens[:, t] = next_token
|
generated_tokens[:, t] = next_token
|
||||||
|
|
||||||
# Check for EOS token - if all batches have generated EOS, stop
|
# 3. update the finished mask
|
||||||
if tokenizer.eos_token_id is not None:
|
if tokenizer.eos_token_id is not None:
|
||||||
if (next_token == tokenizer.eos_token_id).all():
|
finished |= (next_token == tokenizer.eos_token_id)
|
||||||
|
|
||||||
|
# 4. break only if everyone is finished
|
||||||
|
if finished.all():
|
||||||
break
|
break
|
||||||
|
|
||||||
# Embed the generated token and append to prefix
|
next_token_unsqueezed = next_token.unsqueeze(1)
|
||||||
next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1)
|
|
||||||
|
|
||||||
def next_token_embed_func(next_token_unsqueezed):
|
def next_token_embed_func(next_token_unsqueezed):
|
||||||
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
|
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
|
||||||
next_emb_dim = next_emb.shape[-1]
|
return next_emb * math.sqrt(next_emb.shape[-1])
|
||||||
return next_emb * math.sqrt(next_emb_dim)
|
|
||||||
|
|
||||||
next_emb = self._apply_checkpoint(next_token_embed_func, next_token_unsqueezed)
|
next_emb = self._apply_checkpoint(next_token_embed_func, next_token_unsqueezed)
|
||||||
|
|
||||||
# Append to prefix embeddings
|
# update embeddings
|
||||||
prefix_embs = torch.cat([prefix_embs, next_emb], dim=1)
|
prefix_embs = torch.cat([prefix_embs, next_emb], dim=1)
|
||||||
|
|
||||||
# Update masks - new token is valid and uses causal attention
|
# update padding masks
|
||||||
prefix_pad_masks = torch.cat([
|
prefix_pad_masks = torch.cat([
|
||||||
prefix_pad_masks,
|
prefix_pad_masks,
|
||||||
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||||
], dim=1)
|
], dim=1)
|
||||||
prefix_att_masks = torch.cat([prefix_att_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
|
|
||||||
|
# update attention masks
|
||||||
|
old_seq_len = prefix_att_masks.shape[1]
|
||||||
|
new_seq_len = old_seq_len + 1
|
||||||
|
new_att_masks = torch.zeros((bsize, new_seq_len, new_seq_len), dtype=torch.bool, device=device)
|
||||||
|
new_att_masks[:, :old_seq_len, :old_seq_len] = prefix_att_masks
|
||||||
|
new_att_masks[:, -1, :] = prefix_pad_masks
|
||||||
|
prefix_att_masks = new_att_masks
|
||||||
|
|
||||||
return generated_tokens
|
return generated_tokens
|
||||||
|
|
||||||
@@ -978,13 +1334,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
subtask_masks = torch.ones_like(generated_subtask_tokens, dtype=torch.bool)
|
subtask_masks = torch.ones_like(generated_subtask_tokens, dtype=torch.bool)
|
||||||
|
|
||||||
# During inference, we don't have subtask_tokens yet, so pass None
|
# During inference, we don't have subtask_tokens yet, so pass None
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
|
# Also no FAST tokens during inference
|
||||||
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks
|
prefix_embs, prefix_pad_masks, prefix_att_masks, _, _, _ = self.embed_prefix(
|
||||||
|
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks,
|
||||||
|
fast_action_tokens=None, fast_action_masks=None
|
||||||
)
|
)
|
||||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
|
||||||
|
# Convert embeddings to bfloat16 if needed for the model
|
||||||
|
if (
|
||||||
|
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
|
== torch.bfloat16
|
||||||
|
):
|
||||||
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# prefix_att_masks is already a 2D attention mask from embed_prefix
|
||||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
|
||||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype)
|
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
_, past_key_values = self.paligemma_with_expert.forward(
|
_, past_key_values = self.paligemma_with_expert.forward(
|
||||||
@@ -1209,7 +1575,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
print(f"Remapped {remap_count} state dict keys")
|
print(f"Remapped {remap_count} state dict keys")
|
||||||
|
|
||||||
# Load the remapped state dict into the model
|
# Load the remapped state dict into the model
|
||||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)
|
||||||
|
|
||||||
if missing_keys:
|
if missing_keys:
|
||||||
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
|
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
|
||||||
@@ -1442,19 +1808,26 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
# Decode and print ground truth subtask tokens during training
|
# Decode and print ground truth subtask tokens during training
|
||||||
if self.tokenizer is not None and self.training:
|
# if self.tokenizer is not None and self.training:
|
||||||
bsize = subtask_tokens.shape[0]
|
# bsize = subtask_tokens.shape[0]
|
||||||
for i in range(bsize):
|
# for i in range(bsize):
|
||||||
# Remove padding tokens (0) and special tokens
|
# # Remove padding tokens (0) and special tokens
|
||||||
valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
|
# valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
|
||||||
if len(valid_tokens) > 0:
|
# # if len(valid_tokens) > 0:
|
||||||
decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
# # decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
||||||
# print(f"[Training] Ground truth subtask {i}: {decoded_text}")
|
# # print(f"[Training] Ground truth subtask {i}: {decoded_text}")
|
||||||
|
|
||||||
|
# Get FAST action tokens from batch
|
||||||
|
fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens)
|
||||||
|
fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens)
|
||||||
# Compute loss (no separate state needed for PI05)
|
# Compute loss (no separate state needed for PI05)
|
||||||
# high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
|
# high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
|
||||||
# subtask_tokens = subtask tokens to predict (e.g., "pick up the cup")
|
# subtask_tokens = subtask tokens to predict (e.g., "pick up the cup")
|
||||||
loss_dict = self.model.forward(images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions)
|
# fast_action_tokens = discrete action token IDs to predict
|
||||||
|
loss_dict = self.model.forward(
|
||||||
|
images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions,
|
||||||
|
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
|
||||||
|
)
|
||||||
|
|
||||||
# Extract the total loss
|
# Extract the total loss
|
||||||
loss = loss_dict["loss"]
|
loss = loss_dict["loss"]
|
||||||
@@ -1464,6 +1837,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
||||||
"subtask_loss": loss_dict["subtask_loss"].item(),
|
"subtask_loss": loss_dict["subtask_loss"].item(),
|
||||||
|
"fast_loss": loss_dict["fast_loss"].item(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return loss, detailed_loss_dict
|
return loss, detailed_loss_dict
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from lerobot.processor import (
|
|||||||
ProcessorStep,
|
ProcessorStep,
|
||||||
ProcessorStepRegistry,
|
ProcessorStepRegistry,
|
||||||
RenameObservationsProcessorStep,
|
RenameObservationsProcessorStep,
|
||||||
|
ActionTokenizerProcessorStep,
|
||||||
TokenizerProcessorStep,
|
TokenizerProcessorStep,
|
||||||
UnnormalizerProcessorStep,
|
UnnormalizerProcessorStep,
|
||||||
)
|
)
|
||||||
@@ -158,7 +159,6 @@ def make_pi05_pre_post_processors(
|
|||||||
Returns:
|
Returns:
|
||||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Add remaining processors
|
# Add remaining processors
|
||||||
input_steps: list[ProcessorStep] = [
|
input_steps: list[ProcessorStep] = [
|
||||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||||
@@ -177,6 +177,9 @@ def make_pi05_pre_post_processors(
|
|||||||
padding_side="right",
|
padding_side="right",
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
),
|
),
|
||||||
|
ActionTokenizerProcessorStep(
|
||||||
|
tokenizer_name="physical-intelligence/fast",
|
||||||
|
),
|
||||||
DeviceProcessorStep(device=config.device),
|
DeviceProcessorStep(device=config.device),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ from .policy_robot_bridge import (
|
|||||||
RobotActionToPolicyActionProcessorStep,
|
RobotActionToPolicyActionProcessorStep,
|
||||||
)
|
)
|
||||||
from .rename_processor import RenameObservationsProcessorStep
|
from .rename_processor import RenameObservationsProcessorStep
|
||||||
from .tokenizer_processor import TokenizerProcessorStep
|
from .tokenizer_processor import TokenizerProcessorStep, ActionTokenizerProcessorStep
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ActionProcessorStep",
|
"ActionProcessorStep",
|
||||||
|
|||||||
@@ -15,10 +15,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script defines a processor for tokenizing natural language instructions from an environment transition.
|
This script defines processors for tokenizing data from an environment transition.
|
||||||
|
|
||||||
It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
|
It includes:
|
||||||
token IDs and attention masks, which are then added to the observation dictionary.
|
- TokenizerProcessorStep: Uses a tokenizer from the Hugging Face `transformers` library to convert
|
||||||
|
task descriptions (text) into token IDs and attention masks, which are then added to the observation dictionary.
|
||||||
|
- ActionTokenizerProcessorStep: Uses a processor/tokenizer (e.g., the Physical Intelligence "fast" tokenizer)
|
||||||
|
to tokenize action tensors into discrete token IDs for action modeling.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -30,6 +33,8 @@ import torch
|
|||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
|
ACTION_TOKEN_MASK,
|
||||||
|
ACTION_TOKENS,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
|
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
||||||
@@ -40,12 +45,13 @@ from lerobot.utils.constants import (
|
|||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
from .core import EnvTransition, TransitionKey
|
from .core import EnvTransition, TransitionKey
|
||||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||||
|
|
||||||
# Conditional import for type checking and lazy loading
|
# Conditional import for type checking and lazy loading
|
||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoProcessor, AutoTokenizer
|
||||||
else:
|
else:
|
||||||
|
AutoProcessor = None
|
||||||
AutoTokenizer = None
|
AutoTokenizer = None
|
||||||
|
|
||||||
|
|
||||||
@@ -423,3 +429,233 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="action_tokenizer_processor")
|
||||||
|
class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||||
|
"""
|
||||||
|
Processor step to tokenize action data using a fast action tokenizer.
|
||||||
|
|
||||||
|
This step takes action tensors from an `EnvTransition`, tokenizes them using
|
||||||
|
a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
|
||||||
|
and returns the tokenized action.
|
||||||
|
|
||||||
|
Requires the `transformers` library to be installed.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
||||||
|
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||||
|
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||||
|
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer_name: str | None = None
|
||||||
|
tokenizer: Any | None = None
|
||||||
|
trust_remote_code: bool = True
|
||||||
|
max_action_tokens: int = 32
|
||||||
|
# Internal tokenizer instance (not part of the config)
|
||||||
|
action_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""
|
||||||
|
Initializes the action tokenizer after the dataclass is created.
|
||||||
|
|
||||||
|
It checks for the availability of the `transformers` library and loads the tokenizer
|
||||||
|
either from a provided object or by name from the Hugging Face Hub.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If the `transformers` library is not installed.
|
||||||
|
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
|
||||||
|
"""
|
||||||
|
if not _transformers_available:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'transformers' library is not installed. "
|
||||||
|
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.tokenizer is not None:
|
||||||
|
# Use provided tokenizer object directly
|
||||||
|
self.action_tokenizer = self.tokenizer
|
||||||
|
elif self.tokenizer_name is not None:
|
||||||
|
if AutoProcessor is None:
|
||||||
|
raise ImportError("AutoProcessor is not available")
|
||||||
|
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||||
|
self.tokenizer_name, trust_remote_code=self.trust_remote_code
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
|
||||||
|
"Pass a tokenizer object directly or a tokenizer name to auto-load."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
"""
|
||||||
|
Applies action tokenization to the transition.
|
||||||
|
|
||||||
|
This overrides the base class to handle both tokens and mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transition: The input transition with action data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed transition with tokenized actions and mask in complementary data.
|
||||||
|
"""
|
||||||
|
self._current_transition = transition.copy()
|
||||||
|
new_transition = self._current_transition
|
||||||
|
|
||||||
|
action = new_transition.get(TransitionKey.ACTION)
|
||||||
|
if action is None:
|
||||||
|
raise ValueError("ActionTokenizerProcessorStep requires an action in the transition.")
|
||||||
|
|
||||||
|
# Tokenize and get both tokens and mask
|
||||||
|
tokens, mask = self._tokenize_action(action)
|
||||||
|
|
||||||
|
# Store mask in complementary data
|
||||||
|
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
|
if complementary_data is None:
|
||||||
|
complementary_data = {}
|
||||||
|
complementary_data[ACTION_TOKEN_MASK] = mask
|
||||||
|
complementary_data[ACTION_TOKENS] = tokens
|
||||||
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Tokenizes the action tensor and creates a mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (tokens, mask) where:
|
||||||
|
- tokens: Tensor of token IDs with shape (B, max_action_tokens)
|
||||||
|
- mask: Boolean mask with shape (B, max_action_tokens), True for real tokens, False for padding
|
||||||
|
"""
|
||||||
|
if action is None:
|
||||||
|
raise ValueError("Action cannot be None")
|
||||||
|
|
||||||
|
# Get the device and dtype of the input action
|
||||||
|
device = action.device if isinstance(action, torch.Tensor) else None
|
||||||
|
|
||||||
|
# Handle single sample (add batch dimension)
|
||||||
|
single_sample = action.dim() == 1
|
||||||
|
if single_sample:
|
||||||
|
action = action.unsqueeze(0)
|
||||||
|
|
||||||
|
batch_size = action.shape[0]
|
||||||
|
|
||||||
|
# Tokenize the action batch
|
||||||
|
# The fast tokenizer expects action data and returns token IDs
|
||||||
|
tokens_list = []
|
||||||
|
masks_list = []
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
||||||
|
action_cpu = action[i:i+1].cpu()
|
||||||
|
tokens = self.action_tokenizer(action_cpu)
|
||||||
|
|
||||||
|
# Convert to numpy array if it's a list
|
||||||
|
if isinstance(tokens, list):
|
||||||
|
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
|
||||||
|
elif not isinstance(tokens, torch.Tensor):
|
||||||
|
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
|
||||||
|
else:
|
||||||
|
# Move tokens back to the same device as input action
|
||||||
|
tokens = tokens.to(device=action.device)
|
||||||
|
|
||||||
|
# Flatten to 1D if needed
|
||||||
|
if tokens.dim() > 1:
|
||||||
|
tokens = tokens.flatten()
|
||||||
|
|
||||||
|
# Truncate or pad to max_action_tokens
|
||||||
|
if len(tokens) > self.max_action_tokens:
|
||||||
|
tokens = tokens[:self.max_action_tokens]
|
||||||
|
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
||||||
|
else:
|
||||||
|
mask = torch.cat([
|
||||||
|
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
||||||
|
torch.zeros(self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device)
|
||||||
|
])
|
||||||
|
# Pad tokens with zeros
|
||||||
|
tokens = torch.nn.functional.pad(
|
||||||
|
tokens,
|
||||||
|
(0, self.max_action_tokens - len(tokens)),
|
||||||
|
value=0
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens_list.append(tokens)
|
||||||
|
masks_list.append(mask)
|
||||||
|
|
||||||
|
# Stack into batched tensors
|
||||||
|
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
||||||
|
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
||||||
|
|
||||||
|
# Remove batch dimension if input was single sample
|
||||||
|
if single_sample:
|
||||||
|
tokens_batch = tokens_batch.squeeze(0)
|
||||||
|
masks_batch = masks_batch.squeeze(0)
|
||||||
|
|
||||||
|
# Move to the same device as the input
|
||||||
|
if device is not None:
|
||||||
|
tokens_batch = tokens_batch.to(device)
|
||||||
|
masks_batch = masks_batch.to(device)
|
||||||
|
|
||||||
|
return tokens_batch, masks_batch
|
||||||
|
|
||||||
|
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method is not used since we override __call__.
|
||||||
|
Required by ActionProcessorStep ABC.
|
||||||
|
"""
|
||||||
|
tokens, _ = self._tokenize_action(action)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Returns the serializable configuration of the processor.
|
||||||
|
|
||||||
|
Note: The tokenizer object itself is not serialized. If the processor was initialized
|
||||||
|
with a tokenizer name, that name will be included in the config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the processor's configuration parameters.
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"trust_remote_code": self.trust_remote_code,
|
||||||
|
"max_action_tokens": self.max_action_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Only save tokenizer_name if it was used to create the tokenizer
|
||||||
|
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||||
|
config["tokenizer_name"] = self.tokenizer_name
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
"""
|
||||||
|
Updates feature definitions to reflect tokenized actions.
|
||||||
|
|
||||||
|
This updates the policy features dictionary to indicate that the action
|
||||||
|
has been tokenized into a sequence of token IDs with shape (max_action_tokens,).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: The dictionary of existing policy features.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated dictionary of policy features.
|
||||||
|
"""
|
||||||
|
# Update the action feature to reflect the tokenized shape
|
||||||
|
# The action is now a sequence of token IDs
|
||||||
|
if PipelineFeatureType.ACTION in features:
|
||||||
|
# Replace the action feature with the tokenized version
|
||||||
|
features[PipelineFeatureType.ACTION] = {
|
||||||
|
ACTION_TOKENS: PolicyFeature(
|
||||||
|
type=FeatureType.SEQUENCE, # Token sequence
|
||||||
|
shape=(self.max_action_tokens,)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return features
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ def update_policy(
|
|||||||
# Let accelerator handle mixed precision
|
# Let accelerator handle mixed precision
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
loss, output_dict = policy.forward(batch)
|
loss, output_dict = policy.forward(batch)
|
||||||
|
action = policy.select_action(batch)
|
||||||
breakpoint()
|
breakpoint()
|
||||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask"
|
|||||||
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens"
|
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens"
|
||||||
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask"
|
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask"
|
||||||
ACTION = "action"
|
ACTION = "action"
|
||||||
|
ACTION_TOKENS = ACTION + ".tokens"
|
||||||
|
ACTION_TOKEN_MASK = ACTION + ".token_mask"
|
||||||
REWARD = "next.reward"
|
REWARD = "next.reward"
|
||||||
TRUNCATED = "next.truncated"
|
TRUNCATED = "next.truncated"
|
||||||
DONE = "next.done"
|
DONE = "next.done"
|
||||||
|
|||||||
Reference in New Issue
Block a user