mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
add fast tokenizer support
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Example demonstrating how to use the ActionTokenizerProcessorStep to tokenize actions.
|
||||
|
||||
This example shows how to:
|
||||
1. Load a dataset with action data
|
||||
2. Apply the action tokenizer processor to tokenize actions with proper padding/truncation
|
||||
3. Access both the tokenized actions and the attention mask
|
||||
4. Decode tokenized actions back to their original form
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.processor.tokenizer_processor import ActionTokenizerProcessorStep
|
||||
from lerobot.utils.constants import ACTION_TOKEN_MASK
|
||||
|
||||
# Define delta timestamps for the dataset
|
||||
delta_timestamps = {
|
||||
'action': [
|
||||
0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333,
|
||||
0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3,
|
||||
0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335,
|
||||
0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6,
|
||||
0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333,
|
||||
0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9,
|
||||
0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334,
|
||||
1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2,
|
||||
1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333,
|
||||
1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5,
|
||||
1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333
|
||||
]
|
||||
}
|
||||
|
||||
# Load the dataset
|
||||
print("Loading dataset...")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id="local",
|
||||
root="/fsx/jade_choghari/outputs/pgen_annotations1",
|
||||
delta_timestamps=delta_timestamps
|
||||
)
|
||||
|
||||
# Create a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
# Get a batch of data
|
||||
batch = next(iter(dataloader))
|
||||
action_data = batch["action"] # Shape: (batch_size, action_horizon, action_dim)
|
||||
|
||||
print(f"\nOriginal action shape: {action_data.shape}")
|
||||
print(f"Original action data (first sample, first timestep):\n{action_data[0, 0]}")
|
||||
|
||||
# Method 1: Using the tokenizer directly (as in fast_tokenize.py)
|
||||
print("\n" + "="*80)
|
||||
print("Method 1: Direct tokenizer usage")
|
||||
print("="*80)
|
||||
|
||||
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
|
||||
# Tokenize directly
|
||||
tokens = tokenizer(action_data)
|
||||
print(f"\nDirect tokenization result type: {type(tokens)}")
|
||||
print(f"Tokens shape/length: {tokens.shape if isinstance(tokens, torch.Tensor) else len(tokens)}")
|
||||
|
||||
# Decode
|
||||
decoded_actions = tokenizer.decode(tokens)
|
||||
print(f"Decoded actions shape: {decoded_actions.shape}")
|
||||
reconstruction_error = torch.abs(action_data - decoded_actions).mean()
|
||||
print(f"Mean absolute reconstruction error: {reconstruction_error.item():.6f}")
|
||||
|
||||
# Method 2: Using the ActionTokenizerProcessorStep with proper padding/truncation
|
||||
print("\n" + "="*80)
|
||||
print("Method 2: Using ActionTokenizerProcessorStep (with padding & mask)")
|
||||
print("="*80)
|
||||
|
||||
# Create the action tokenizer processor step
|
||||
action_tokenizer_processor = ActionTokenizerProcessorStep(
|
||||
tokenizer_name="physical-intelligence/fast",
|
||||
trust_remote_code=True,
|
||||
max_action_tokens=32, # Maximum number of tokens per action
|
||||
)
|
||||
|
||||
# Create a transition with the action data
|
||||
transition = {
|
||||
TransitionKey.ACTION: action_data,
|
||||
TransitionKey.OBSERVATION: {}, # Empty for this example
|
||||
}
|
||||
|
||||
# Apply the processor
|
||||
processed_transition = action_tokenizer_processor(transition)
|
||||
|
||||
# Extract tokenized actions and mask
|
||||
tokenized_actions = processed_transition[TransitionKey.ACTION]
|
||||
complementary_data = processed_transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
action_mask = complementary_data[ACTION_TOKEN_MASK]
|
||||
|
||||
print(f"\nTokenized actions shape: {tokenized_actions.shape}") # (batch_size, max_action_tokens)
|
||||
print(f"Action mask shape: {action_mask.shape}") # (batch_size, max_action_tokens)
|
||||
print(f"Tokenized actions dtype: {tokenized_actions.dtype}")
|
||||
print(f"Action mask dtype: {action_mask.dtype}")
|
||||
|
||||
# Show token statistics
|
||||
print(f"\nFirst sample tokens: {tokenized_actions[0]}")
|
||||
print(f"First sample mask: {action_mask[0]}")
|
||||
num_real_tokens = action_mask[0].sum().item()
|
||||
print(f"Number of real tokens (non-padding): {num_real_tokens}")
|
||||
print(f"Number of padding tokens: {action_mask.shape[1] - num_real_tokens}")
|
||||
|
||||
# Decode using the mask
|
||||
print("\nDecoding tokenized actions...")
|
||||
decoded_with_processor = tokenizer.decode(tokenized_actions)
|
||||
print(f"Decoded actions shape: {decoded_with_processor.shape}")
|
||||
|
||||
# Calculate reconstruction error
|
||||
reconstruction_error_processor = torch.abs(action_data - decoded_with_processor).mean()
|
||||
print(f"Mean absolute reconstruction error: {reconstruction_error_processor.item():.6f}")
|
||||
|
||||
# Show that masking works correctly
|
||||
print("\n" + "="*80)
|
||||
print("Mask demonstration")
|
||||
print("="*80)
|
||||
for i in range(min(4, tokenized_actions.shape[0])):
|
||||
mask_i = action_mask[i]
|
||||
num_real = mask_i.sum().item()
|
||||
print(f"Sample {i}: {num_real} real tokens, {len(mask_i) - num_real} padding tokens")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("Action tokenization example completed successfully!")
|
||||
print("="*80)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import numpy as np
|
||||
from transformers import AutoProcessor
|
||||
import torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
# Load the tokenizer from the Hugging Face hub
|
||||
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
|
||||
# Tokenize & decode action chunks (we use dummy data here)
|
||||
action_data = batch["action"] # one batch of action chunks
|
||||
tokens = tokenizer(action_data) # tokens = list[int]
|
||||
decoded_actions = tokenizer.decode(tokens)
|
||||
print("tokenized actions: ", tokens)
|
||||
@@ -10,13 +10,13 @@ from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
||||
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
||||
pretrained_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
|
||||
)
|
||||
|
||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||
@@ -45,13 +45,12 @@ dataloader = torch.utils.data.DataLoader(
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
batch = pre_processor(batch)
|
||||
breakpoint()
|
||||
policy.train()
|
||||
# run inference
|
||||
# action = policy.select_action(batch)
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
breakpoint()
|
||||
# import requests
|
||||
# from PIL import Image
|
||||
# from transformers import AutoProcessor
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
## One-sentence answer
|
||||
|
||||
> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**.
|
||||
|
||||
Everything else you’ve seen so far was just metadata.
|
||||
|
||||
---
|
||||
|
||||
## What goes in
|
||||
|
||||
### Inputs
|
||||
|
||||
```python
|
||||
prefix_pad_masks # shape [B, L]
|
||||
prefix_att_masks # shape [B, L]
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
* `prefix_pad_masks[b, i] = True`
|
||||
→ token `i` exists (not padding)
|
||||
|
||||
* `prefix_att_masks[b, i] = False`
|
||||
→ token `i` is **bidirectional**
|
||||
|
||||
* `prefix_att_masks[b, i] = True`
|
||||
→ token `i` is **causal (autoregressive)**
|
||||
|
||||
---
|
||||
|
||||
## What comes out
|
||||
|
||||
```python
|
||||
att_2d_prefix # shape [B, L, L]
|
||||
```
|
||||
|
||||
Each entry:
|
||||
|
||||
```text
|
||||
att_2d_prefix[b, i, j] = True
|
||||
```
|
||||
|
||||
means:
|
||||
|
||||
> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.”
|
||||
|
||||
---
|
||||
|
||||
## How it is constructed (conceptually)
|
||||
|
||||
For **each batch b**, **each query position i**, **each key position j**:
|
||||
|
||||
```python
|
||||
if not prefix_pad_masks[b, j]:
|
||||
att[b, i, j] = False # cannot attend to padding
|
||||
else if not prefix_att_masks[b, i]:
|
||||
att[b, i, j] = True # bidirectional token → can see all real tokens
|
||||
else:
|
||||
att[b, i, j] = (j <= i) # causal token → can see only past + itself
|
||||
```
|
||||
|
||||
That’s it.
|
||||
|
||||
---
|
||||
|
||||
## Tiny concrete example (exactly matching your code)
|
||||
|
||||
Suppose:
|
||||
|
||||
```python
|
||||
prefix_pad_masks[0] = [T, T, T, T, T, F]
|
||||
prefix_att_masks[0] = [F, F, F, T, T, T]
|
||||
```
|
||||
|
||||
Tokens:
|
||||
|
||||
```
|
||||
0: IMG
|
||||
1: IMG
|
||||
2: LANG
|
||||
3: SUB0
|
||||
4: SUB1
|
||||
5: PAD
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Resulting `att_2d_prefix[0]`
|
||||
|
||||
`✓ = True, ✗ = False`
|
||||
|
||||
| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 |
|
||||
| ---------- | - | - | - | - | - | - |
|
||||
| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
|
||||
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
|
||||
|
||||
---
|
||||
|
||||
## Why this matters for your training code
|
||||
|
||||
This line:
|
||||
|
||||
```python
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
|
||||
```
|
||||
|
||||
Converts `[B, L, L] → [B, 1, L, L]` and possibly flips True/False to `0/-inf`.
|
||||
|
||||
This is **exactly what Paligemma uses inside self-attention**.
|
||||
|
||||
---
|
||||
|
||||
## Key implications (VERY important)
|
||||
|
||||
### 1️⃣ This mask does **not isolate token groups**
|
||||
|
||||
* Bidirectional tokens can attend to **everything**
|
||||
* Causal tokens only restrict *their own row*
|
||||
|
||||
So **flow/action tokens must be blocked separately**.
|
||||
|
||||
---
|
||||
|
||||
### 2️⃣ This is why your AR subtask prediction works
|
||||
|
||||
* Subtask tokens are causal
|
||||
* Output at position `i` predicts token `i+1`
|
||||
* Padding is fully ignored
|
||||
|
||||
---
|
||||
|
||||
### 3️⃣ Inference behavior
|
||||
|
||||
When `subtask_tokens = None`:
|
||||
|
||||
* `prefix_att_masks` contains only `False`
|
||||
* `att_2d_prefix` becomes **fully bidirectional**
|
||||
* No AR behavior remains
|
||||
|
||||
Exactly what you want.
|
||||
|
||||
---
|
||||
|
||||
## One-sentence takeaway (commit this)
|
||||
|
||||
> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.
|
||||
|
||||
If you want next, I can:
|
||||
|
||||
* inspect `make_att_2d_masks()` source with you
|
||||
* show how to block **flow → subtask** attention
|
||||
* explain how this changes when suffix tokens are added
|
||||
* help you refactor this into a cleaner “grouped attention” API
|
||||
|
||||
You’re now at the point where the model’s behavior should feel *predictable*, not magical.
|
||||
@@ -0,0 +1 @@
|
||||
srun --time 12:00:00 --qos=high --gres=gpu:1 --mem=24G --partition=hopper-prod --container-image /fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh --container-mounts /fsx/jade_choghari
|
||||
Reference in New Issue
Block a user