add fast tokenizer support

This commit is contained in:
Jade Choghari
2025-12-16 11:28:27 +00:00
parent fddd044306
commit 8e05dc9a7a
13 changed files with 1081 additions and 204 deletions
@@ -0,0 +1,138 @@
#!/usr/bin/env python
"""
Example demonstrating how to use the ActionTokenizerProcessorStep to tokenize actions.
This example shows how to:
1. Load a dataset with action data
2. Apply the action tokenizer processor to tokenize actions with proper padding/truncation
3. Access both the tokenized actions and the attention mask
4. Decode tokenized actions back to their original form
"""
import torch
from transformers import AutoProcessor
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.processor.tokenizer_processor import ActionTokenizerProcessorStep
from lerobot.utils.constants import ACTION_TOKEN_MASK
# Define delta timestamps for the dataset
delta_timestamps = {
'action': [
0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333,
0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3,
0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335,
0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6,
0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333,
0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9,
0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334,
1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2,
1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333,
1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5,
1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333
]
}
# Load the dataset
print("Loading dataset...")
dataset = LeRobotDataset(
repo_id="local",
root="/fsx/jade_choghari/outputs/pgen_annotations1",
delta_timestamps=delta_timestamps
)
# Create a dataloader
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
# Get a batch of data
batch = next(iter(dataloader))
action_data = batch["action"] # Shape: (batch_size, action_horizon, action_dim)
print(f"\nOriginal action shape: {action_data.shape}")
print(f"Original action data (first sample, first timestep):\n{action_data[0, 0]}")
# Method 1: Using the tokenizer directly (as in fast_tokenize.py)
print("\n" + "="*80)
print("Method 1: Direct tokenizer usage")
print("="*80)
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# Tokenize directly
tokens = tokenizer(action_data)
print(f"\nDirect tokenization result type: {type(tokens)}")
print(f"Tokens shape/length: {tokens.shape if isinstance(tokens, torch.Tensor) else len(tokens)}")
# Decode
decoded_actions = tokenizer.decode(tokens)
print(f"Decoded actions shape: {decoded_actions.shape}")
reconstruction_error = torch.abs(action_data - decoded_actions).mean()
print(f"Mean absolute reconstruction error: {reconstruction_error.item():.6f}")
# Method 2: Using the ActionTokenizerProcessorStep with proper padding/truncation
print("\n" + "="*80)
print("Method 2: Using ActionTokenizerProcessorStep (with padding & mask)")
print("="*80)
# Create the action tokenizer processor step
action_tokenizer_processor = ActionTokenizerProcessorStep(
tokenizer_name="physical-intelligence/fast",
trust_remote_code=True,
max_action_tokens=32, # Maximum number of tokens per action
)
# Create a transition with the action data
transition = {
TransitionKey.ACTION: action_data,
TransitionKey.OBSERVATION: {}, # Empty for this example
}
# Apply the processor
processed_transition = action_tokenizer_processor(transition)
# Extract tokenized actions and mask
tokenized_actions = processed_transition[TransitionKey.ACTION]
complementary_data = processed_transition[TransitionKey.COMPLEMENTARY_DATA]
action_mask = complementary_data[ACTION_TOKEN_MASK]
print(f"\nTokenized actions shape: {tokenized_actions.shape}") # (batch_size, max_action_tokens)
print(f"Action mask shape: {action_mask.shape}") # (batch_size, max_action_tokens)
print(f"Tokenized actions dtype: {tokenized_actions.dtype}")
print(f"Action mask dtype: {action_mask.dtype}")
# Show token statistics
print(f"\nFirst sample tokens: {tokenized_actions[0]}")
print(f"First sample mask: {action_mask[0]}")
num_real_tokens = action_mask[0].sum().item()
print(f"Number of real tokens (non-padding): {num_real_tokens}")
print(f"Number of padding tokens: {action_mask.shape[1] - num_real_tokens}")
# Decode using the mask
print("\nDecoding tokenized actions...")
decoded_with_processor = tokenizer.decode(tokenized_actions)
print(f"Decoded actions shape: {decoded_with_processor.shape}")
# Calculate reconstruction error
reconstruction_error_processor = torch.abs(action_data - decoded_with_processor).mean()
print(f"Mean absolute reconstruction error: {reconstruction_error_processor.item():.6f}")
# Show that masking works correctly
print("\n" + "="*80)
print("Mask demonstration")
print("="*80)
for i in range(min(4, tokenized_actions.shape[0])):
mask_i = action_mask[i]
num_real = mask_i.sum().item()
print(f"Sample {i}: {num_real} real tokens, {len(mask_i) - num_real} padding tokens")
print("\n" + "="*80)
print("Action tokenization example completed successfully!")
print("="*80)
+25
View File
@@ -0,0 +1,25 @@
import numpy as np
from transformers import AutoProcessor
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
batch = next(iter(dataloader))
# Load the tokenizer from the Hugging Face hub
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# Tokenize & decode action chunks (we use dummy data here)
action_data = batch["action"] # one batch of action chunks
tokens = tokenizer(action_data) # tokens = list[int]
decoded_actions = tokenizer.decode(tokens)
print("tokenized actions: ", tokens)
+3 -4
View File
@@ -10,13 +10,13 @@ from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.configs.policies import PreTrainedConfig
cfg = PreTrainedConfig.from_pretrained(
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
)
cfg.dtype = "bfloat16"
pre_processor, post_processor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
pretrained_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
)
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
@@ -45,13 +45,12 @@ dataloader = torch.utils.data.DataLoader(
)
batch = next(iter(dataloader))
batch = pre_processor(batch)
breakpoint()
policy.train()
# run inference
# action = policy.select_action(batch)
loss, loss_dict = policy.forward(batch)
breakpoint()
# import requests
# from PIL import Image
# from transformers import AutoProcessor
+159
View File
@@ -0,0 +1,159 @@
## One-sentence answer
> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**.
Everything else youve seen so far was just metadata.
---
## What goes in
### Inputs
```python
prefix_pad_masks # shape [B, L]
prefix_att_masks # shape [B, L]
```
Where:
* `prefix_pad_masks[b, i] = True`
→ token `i` exists (not padding)
* `prefix_att_masks[b, i] = False`
→ token `i` is **bidirectional**
* `prefix_att_masks[b, i] = True`
→ token `i` is **causal (autoregressive)**
---
## What comes out
```python
att_2d_prefix # shape [B, L, L]
```
Each entry:
```text
att_2d_prefix[b, i, j] = True
```
means:
> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.”
---
## How it is constructed (conceptually)
For **each batch b**, **each query position i**, **each key position j**:
```python
if not prefix_pad_masks[b, j]:
att[b, i, j] = False # cannot attend to padding
else if not prefix_att_masks[b, i]:
att[b, i, j] = True # bidirectional token → can see all real tokens
else:
att[b, i, j] = (j <= i) # causal token → can see only past + itself
```
Thats it.
---
## Tiny concrete example (exactly matching your code)
Suppose:
```python
prefix_pad_masks[0] = [T, T, T, T, T, F]
prefix_att_masks[0] = [F, F, F, T, T, T]
```
Tokens:
```
0: IMG
1: IMG
2: LANG
3: SUB0
4: SUB1
5: PAD
```
---
### Resulting `att_2d_prefix[0]`
`✓ = True, ✗ = False`
| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 |
| ---------- | - | - | - | - | - | - |
| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
---
## Why this matters for your training code
This line:
```python
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
```
Converts `[B, L, L] → [B, 1, L, L]` and possibly flips True/False to `0/-inf`.
This is **exactly what Paligemma uses inside self-attention**.
---
## Key implications (VERY important)
### 1️⃣ This mask does **not isolate token groups**
* Bidirectional tokens can attend to **everything**
* Causal tokens only restrict *their own row*
So **flow/action tokens must be blocked separately**.
---
### 2️⃣ This is why your AR subtask prediction works
* Subtask tokens are causal
* Output at position `i` predicts token `i+1`
* Padding is fully ignored
---
### 3️⃣ Inference behavior
When `subtask_tokens = None`:
* `prefix_att_masks` contains only `False`
* `att_2d_prefix` becomes **fully bidirectional**
* No AR behavior remains
Exactly what you want.
---
## One-sentence takeaway (commit this)
> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.
If you want next, I can:
* inspect `make_att_2d_masks()` source with you
* show how to block **flow → subtask** attention
* explain how this changes when suffix tokens are added
* help you refactor this into a cleaner “grouped attention” API
Youre now at the point where the models behavior should feel *predictable*, not magical.
+1
View File
@@ -0,0 +1 @@
srun --time 12:00:00 --qos=high --gres=gpu:1 --mem=24G --partition=hopper-prod --container-image /fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh --container-mounts /fsx/jade_choghari