From 8e05dc9a7ad6cdf1a02b66ac9c0df24acbb1535f Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 16 Dec 2025 11:28:27 +0000 Subject: [PATCH] add fast tokenizer support --- examples/dataset/action_tokenizer_example.py | 138 ++++ examples/dataset/fast_tokenize.py | 25 + examples/dataset/inference_pi05.py | 7 +- examples/dataset/mask.md | 159 +++++ examples/dataset/test.txt | 1 + src/lerobot/policies/pi05/compare.txt | 64 -- .../policies/pi05/configuration_pi05.py | 3 + src/lerobot/policies/pi05/modeling_pi05.py | 630 ++++++++++++++---- src/lerobot/policies/pi05/processor_pi05.py | 7 +- src/lerobot/processor/__init__.py | 2 +- src/lerobot/processor/tokenizer_processor.py | 246 ++++++- src/lerobot/scripts/lerobot_train.py | 1 + src/lerobot/utils/constants.py | 2 + 13 files changed, 1081 insertions(+), 204 deletions(-) create mode 100644 examples/dataset/action_tokenizer_example.py create mode 100644 examples/dataset/fast_tokenize.py create mode 100644 examples/dataset/mask.md create mode 100644 examples/dataset/test.txt delete mode 100644 src/lerobot/policies/pi05/compare.txt diff --git a/examples/dataset/action_tokenizer_example.py b/examples/dataset/action_tokenizer_example.py new file mode 100644 index 000000000..1e16a3fc5 --- /dev/null +++ b/examples/dataset/action_tokenizer_example.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python + +""" +Example demonstrating how to use the ActionTokenizerProcessorStep to tokenize actions. + +This example shows how to: +1. Load a dataset with action data +2. Apply the action tokenizer processor to tokenize actions with proper padding/truncation +3. Access both the tokenized actions and the attention mask +4. Decode tokenized actions back to their original form +""" + +import torch +from transformers import AutoProcessor + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.processor.tokenizer_processor import ActionTokenizerProcessorStep +from lerobot.utils.constants import ACTION_TOKEN_MASK + +# Define delta timestamps for the dataset +delta_timestamps = { + 'action': [ + 0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, + 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, + 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, + 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, + 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, + 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, + 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, + 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, + 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, + 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, + 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333 + ] +} + +# Load the dataset +print("Loading dataset...") +dataset = LeRobotDataset( + repo_id="local", + root="/fsx/jade_choghari/outputs/pgen_annotations1", + delta_timestamps=delta_timestamps +) + +# Create a dataloader +dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=0, + batch_size=4, + shuffle=True, +) + +# Get a batch of data +batch = next(iter(dataloader)) +action_data = batch["action"] # Shape: (batch_size, action_horizon, action_dim) + +print(f"\nOriginal action shape: {action_data.shape}") +print(f"Original action data (first sample, first timestep):\n{action_data[0, 0]}") + +# Method 1: Using the tokenizer 
directly (as in fast_tokenize.py) +print("\n" + "="*80) +print("Method 1: Direct tokenizer usage") +print("="*80) + +tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True) + +# Tokenize directly +tokens = tokenizer(action_data) +print(f"\nDirect tokenization result type: {type(tokens)}") +print(f"Tokens shape/length: {tokens.shape if isinstance(tokens, torch.Tensor) else len(tokens)}") + +# Decode +decoded_actions = tokenizer.decode(tokens) +print(f"Decoded actions shape: {decoded_actions.shape}") +reconstruction_error = torch.abs(action_data - decoded_actions).mean() +print(f"Mean absolute reconstruction error: {reconstruction_error.item():.6f}") + +# Method 2: Using the ActionTokenizerProcessorStep with proper padding/truncation +print("\n" + "="*80) +print("Method 2: Using ActionTokenizerProcessorStep (with padding & mask)") +print("="*80) + +# Create the action tokenizer processor step +action_tokenizer_processor = ActionTokenizerProcessorStep( + tokenizer_name="physical-intelligence/fast", + trust_remote_code=True, + max_action_tokens=32, # Maximum number of tokens per action +) + +# Create a transition with the action data +transition = { + TransitionKey.ACTION: action_data, + TransitionKey.OBSERVATION: {}, # Empty for this example +} + +# Apply the processor +processed_transition = action_tokenizer_processor(transition) + +# Extract tokenized actions and mask +tokenized_actions = processed_transition[TransitionKey.ACTION] +complementary_data = processed_transition[TransitionKey.COMPLEMENTARY_DATA] +action_mask = complementary_data[ACTION_TOKEN_MASK] + +print(f"\nTokenized actions shape: {tokenized_actions.shape}") # (batch_size, max_action_tokens) +print(f"Action mask shape: {action_mask.shape}") # (batch_size, max_action_tokens) +print(f"Tokenized actions dtype: {tokenized_actions.dtype}") +print(f"Action mask dtype: {action_mask.dtype}") + +# Show token statistics +print(f"\nFirst sample tokens: {tokenized_actions[0]}") +print(f"First sample mask: {action_mask[0]}") +num_real_tokens = action_mask[0].sum().item() +print(f"Number of real tokens (non-padding): {num_real_tokens}") +print(f"Number of padding tokens: {action_mask.shape[1] - num_real_tokens}") + +# Decode using the mask +print("\nDecoding tokenized actions...") +decoded_with_processor = tokenizer.decode(tokenized_actions) +print(f"Decoded actions shape: {decoded_with_processor.shape}") + +# Calculate reconstruction error +reconstruction_error_processor = torch.abs(action_data - decoded_with_processor).mean() +print(f"Mean absolute reconstruction error: {reconstruction_error_processor.item():.6f}") + +# Show that masking works correctly +print("\n" + "="*80) +print("Mask demonstration") +print("="*80) +for i in range(min(4, tokenized_actions.shape[0])): + mask_i = action_mask[i] + num_real = mask_i.sum().item() + print(f"Sample {i}: {num_real} real tokens, {len(mask_i) - num_real} padding tokens") + +print("\n" + "="*80) +print("Action tokenization example completed successfully!") +print("="*80) + diff --git a/examples/dataset/fast_tokenize.py b/examples/dataset/fast_tokenize.py new file mode 100644 index 000000000..a730a31b3 --- /dev/null +++ b/examples/dataset/fast_tokenize.py @@ -0,0 +1,25 @@ +import numpy as np +from transformers import AutoProcessor +import torch +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata + +delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 
0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]} +dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps) + +dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=0, + batch_size=4, + shuffle=True, +) + +batch = next(iter(dataloader)) + +# Load the tokenizer from the Hugging Face hub +tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True) + +# Tokenize & decode action chunks (we use dummy data here) +action_data = batch["action"] # one batch of action chunks +tokens = tokenizer(action_data) # tokens = list[int] +decoded_actions = tokenizer.decode(tokens) +print("tokenized actions: ", tokens) diff --git a/examples/dataset/inference_pi05.py b/examples/dataset/inference_pi05.py index 1590c11dc..f8e38760f 100644 --- a/examples/dataset/inference_pi05.py +++ b/examples/dataset/inference_pi05.py @@ -10,13 +10,13 @@ from lerobot.policies.factory import make_policy, make_policy_config from lerobot.configs.policies import PreTrainedConfig cfg = PreTrainedConfig.from_pretrained( - pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model", + pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model", ) cfg.dtype = "bfloat16" pre_processor, post_processor = make_pre_post_processors( policy_cfg=cfg, - pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model", + pretrained_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model", ) delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]} @@ -45,13 +45,12 @@ dataloader = torch.utils.data.DataLoader( ) batch = next(iter(dataloader)) - batch = pre_processor(batch) -breakpoint() policy.train() # run inference # action = policy.select_action(batch) loss, loss_dict = policy.forward(batch) +breakpoint() # import requests # from PIL import Image # from transformers import AutoProcessor diff --git a/examples/dataset/mask.md b/examples/dataset/mask.md new file mode 100644 index 000000000..a55a90a39 --- /dev/null +++ b/examples/dataset/mask.md @@ -0,0 +1,159 @@ +## One-sentence 
answer + +> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**. + +Everything else you’ve seen so far was just metadata. + +--- + +## What goes in + +### Inputs + +```python +prefix_pad_masks # shape [B, L] +prefix_att_masks # shape [B, L] +``` + +Where: + +* `prefix_pad_masks[b, i] = True` + → token `i` exists (not padding) + +* `prefix_att_masks[b, i] = False` + → token `i` is **bidirectional** + +* `prefix_att_masks[b, i] = True` + → token `i` is **causal (autoregressive)** + +--- + +## What comes out + +```python +att_2d_prefix # shape [B, L, L] +``` + +Each entry: + +```text +att_2d_prefix[b, i, j] = True +``` + +means: + +> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.” + +--- + +## How it is constructed (conceptually) + +For **each batch b**, **each query position i**, **each key position j**: + +```python +if not prefix_pad_masks[b, j]: + att[b, i, j] = False # cannot attend to padding +else if not prefix_att_masks[b, i]: + att[b, i, j] = True # bidirectional token → can see all real tokens +else: + att[b, i, j] = (j <= i) # causal token → can see only past + itself +``` + +That’s it. + +--- + +## Tiny concrete example (exactly matching your code) + +Suppose: + +```python +prefix_pad_masks[0] = [T, T, T, T, T, F] +prefix_att_masks[0] = [F, F, F, T, T, T] +``` + +Tokens: + +``` +0: IMG +1: IMG +2: LANG +3: SUB0 +4: SUB1 +5: PAD +``` + +--- + +### Resulting `att_2d_prefix[0]` + +`✓ = True, ✗ = False` + +| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 | +| ---------- | - | - | - | - | - | - | +| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | +| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | +| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | +| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | +| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | +| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | + +--- + +## Why this matters for your training code + +This line: + +```python +att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix) +``` + +Converts `[B, L, L] → [B, 1, L, L]` and possibly flips True/False to `0/-inf`. + +This is **exactly what Paligemma uses inside self-attention**. + +--- + +## Key implications (VERY important) + +### 1️⃣ This mask does **not isolate token groups** + +* Bidirectional tokens can attend to **everything** +* Causal tokens only restrict *their own row* + +So **flow/action tokens must be blocked separately**. + +--- + +### 2️⃣ This is why your AR subtask prediction works + +* Subtask tokens are causal +* Output at position `i` predicts token `i+1` +* Padding is fully ignored + +--- + +### 3️⃣ Inference behavior + +When `subtask_tokens = None`: + +* `prefix_att_masks` contains only `False` +* `att_2d_prefix` becomes **fully bidirectional** +* No AR behavior remains + +Exactly what you want. + +--- + +## One-sentence takeaway (commit this) + +> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses. + +If you want next, I can: + +* inspect `make_att_2d_masks()` source with you +* show how to block **flow → subtask** attention +* explain how this changes when suffix tokens are added +* help you refactor this into a cleaner “grouped attention” API + +You’re now at the point where the model’s behavior should feel *predictable*, not magical. 
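+---
+
+## Appendix: minimal reference sketch
+
+For reference, a small self-contained sketch of the fusion rule described above. This is an illustrative reimplementation under the conventions of this note (`pad_masks[b, i] = True` for real tokens, `att_masks[b, i] = True` for causal tokens), not the library's `make_att_2d_masks`; the function name and shapes here are assumptions for illustration only.
+
+```python
+import torch
+
+
+def build_prefix_attention_mask(pad_masks: torch.Tensor, att_masks: torch.Tensor) -> torch.Tensor:
+    """Fuse padding and causality into a [B, L, L] boolean attention matrix."""
+    _, seq_len = pad_masks.shape
+    idx = torch.arange(seq_len, device=pad_masks.device)
+    causal = idx[None, :] <= idx[:, None]  # [L, L]: key j is at or before query i
+    # Causal queries see only past + self; bidirectional queries see every position.
+    att = torch.where(att_masks[:, :, None], causal[None], torch.ones_like(causal[None]))
+    # No one attends to padding keys, and padded queries attend to nothing.
+    return att & pad_masks[:, None, :] & pad_masks[:, :, None]
+
+
+# Reproduces the 6-token example table above (IMG, IMG, LANG, SUB0, SUB1, PAD).
+pad = torch.tensor([[True, True, True, True, True, False]])
+att = torch.tensor([[False, False, False, True, True, True]])
+print(build_prefix_attention_mask(pad, att)[0].int())
+```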
diff --git a/examples/dataset/test.txt b/examples/dataset/test.txt new file mode 100644 index 000000000..6fd6d8dd2 --- /dev/null +++ b/examples/dataset/test.txt @@ -0,0 +1 @@ +srun --time 12:00:00 --qos=high --gres=gpu:1 --mem=24G --partition=hopper-prod --container-image /fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh --container-mounts /fsx/jade_choghari \ No newline at end of file diff --git a/src/lerobot/policies/pi05/compare.txt b/src/lerobot/policies/pi05/compare.txt deleted file mode 100644 index aaebecec3..000000000 --- a/src/lerobot/policies/pi05/compare.txt +++ /dev/null @@ -1,64 +0,0 @@ - -Fine tune output -(Pdb) images[2].mean() -tensor(-1., device='cuda:0') -(Pdb) images[1].mean() -tensor(-0.5780, device='cuda:0') -(Pdb) images[0].mean() -tensor(-0.7716, device='cuda:0') -(Pdb) (Pdb) high_level_task[0] -tensor([ 2, 7978, 2403, 6911, 235292, 5651, 3124, 573, 18571, - 7762, 6643, 573, 9010, 72993, 21810, 4894, 3040, 235292, - 235248, 235274, 235274, 235274, 728, 235274, 235248, 235284, 235308, - 235308, 235248, 235274, 235318, 235315, 235248, 235274, 235310, 235318, - 235248, 235284, 235318, 235248, 235274, 235284, 235321, 235248, 235274, - 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, - 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, - 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, - 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, - 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, - 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, - 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, - 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, - 235284, 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, - 235321, 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, - 235248, 235274, 235284, 235321, 235248, 235274, 235284, 235321, 235289, - 4284, 8277, 235292, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0], device='cuda:0') -(Pdb) subtask_tokens[0] -tensor([ 2, 28040, 7762, 14574, 6643, 9010, 37901, 21810, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - device='cuda:0') -(Pdb) actions.shape -torch.Size([4, 50, 32]) -(Pdb) actions.mean() -tensor(0.0143, device='cuda:0') -(Pdb) - - - - -Inference: diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 60ea6be87..c64e6f241 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -37,6 +37,9 @@ class PI05Config(PreTrainedConfig): # Shorter state and action vectors will be padded to these dimensions max_state_dim: int = 32 max_action_dim: int = 32 + max_action_tokens: int = 32 + fast_vocab_size: int = 2048 + # Flow 
matching parameters: see openpi `PI0Pytorch` num_inference_steps: int = 10 diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index c9e22c0a1..561324de3 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -537,6 +537,18 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + # FAST action token embedding and prediction head + self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width) + self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size) + + # Apply dtype conversion to FAST layers to match model precision + if config.dtype == "bfloat16": + self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16) + self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16) + elif config.dtype == "float32": + self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32) + self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32) + # Initialize gradient checkpointing flag self.gradient_checkpointing_enabled = False @@ -592,6 +604,194 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` result = result.to(dtype=dtype) return result + def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize): + """Create custom 2D attention mask for the new attention pattern. + + Attention rules: + - Images + Language: bidirectional among themselves, don't attend to subtask or FAST + - Subtask: attend to images + language, causal among themselves, don't attend to FAST + - FAST: attend to images + language + subtask, causal among themselves + + Args: + att_mask_segments: List of (type, length) tuples + pad_masks: Padding masks [B, total_seq_len] + bsize: Batch size + + Returns: + att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len] + """ + total_len = sum(length for _, length in att_mask_segments) + device = pad_masks.device + + # Initialize attention mask as False (cannot attend) + att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device) + + # Track positions for each segment + positions = [] + current_pos = 0 + for seg_type, seg_len in att_mask_segments: + positions.append((seg_type, current_pos, current_pos + seg_len)) + current_pos += seg_len + + # Apply attention rules + for i, (query_type, query_start, query_end) in enumerate(positions): + for j, (key_type, key_start, key_end) in enumerate(positions): + # Images and Language can attend to each other bidirectionally + if query_type in ['image', 'language'] and key_type in ['image', 'language']: + att_2d_masks[:, query_start:query_end, key_start:key_end] = True + + # Subtask tokens attend to images + language + elif query_type == 'subtask' and key_type in ['image', 'language']: + att_2d_masks[:, query_start:query_end, key_start:key_end] = True + + # Subtask tokens attend causally to themselves + elif query_type == 'subtask' and key_type == 'subtask': + # Create causal mask for subtask tokens + subtask_len = query_end - query_start + causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device)) + att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :] + + # FAST tokens attend to images + language + subtask + elif query_type == 
'fast' and key_type in ['image', 'language', 'subtask']: + att_2d_masks[:, query_start:query_end, key_start:key_end] = True + + # FAST tokens attend causally to themselves + elif query_type == 'fast' and key_type == 'fast': + fast_len = query_end - query_start + causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device)) + att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :] + + # Apply padding masks + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + + return att_2d_masks + + def visualize_attention_mask( + self, + att_mask_segments, + att_2d_masks, + save_path, + batch_idx=0, + dpi=150, + max_display_tokens=None + ): + """Visualize the attention mask with labeled segments. + + Args: + att_mask_segments: List of (type, length) tuples defining the segments + att_2d_masks: 2D attention mask tensor [B, total_seq_len, total_seq_len] + save_path: Path where to save the visualization image + batch_idx: Which batch item to visualize (default: 0) + dpi: DPI for the saved image (default: 150) + max_display_tokens: Maximum number of tokens to display (for very long sequences) + """ + try: + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + from matplotlib.colors import LinearSegmentedColormap + except ImportError: + logging.warning("matplotlib not available, skipping attention mask visualization") + return + + # Extract the mask for the specified batch + mask = att_2d_masks[batch_idx].cpu().float().numpy() + + # If sequence is too long, downsample for visualization + if max_display_tokens is not None and mask.shape[0] > max_display_tokens: + # Simple downsampling by taking every Nth token + step = mask.shape[0] // max_display_tokens + mask = mask[::step, ::step] + # Adjust segments accordingly + att_mask_segments = [(seg_type, max(1, seg_len // step)) for seg_type, seg_len in att_mask_segments] + + # Calculate positions for each segment + positions = [] + current_pos = 0 + for seg_type, seg_len in att_mask_segments: + positions.append((seg_type, current_pos, current_pos + seg_len)) + current_pos += seg_len + + # Create figure + fig, ax = plt.subplots(figsize=(12, 10)) + + # Create custom colormap: white for False (no attention), blue for True (attention) + colors = ['white', '#2E86AB'] + n_bins = 2 + cmap = LinearSegmentedColormap.from_list('attention', colors, N=n_bins) + + # Display the mask + im = ax.imshow(mask, cmap=cmap, aspect='auto', interpolation='nearest', vmin=0, vmax=1) + + # Add colorbar + cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + cbar.set_label('Attention Enabled', rotation=270, labelpad=20) + cbar.set_ticks([0.25, 0.75]) + cbar.set_ticklabels(['No', 'Yes']) + + # Define colors for each segment type + segment_colors = { + 'image': '#A23B72', + 'language': '#F18F01', + 'subtask': '#C73E1D', + 'fast': '#6A994E' + } + + # Draw segment boundaries and labels + for seg_type, start, end in positions: + color = segment_colors.get(seg_type, '#666666') + + # Draw vertical lines for columns (keys) + ax.axvline(x=start - 0.5, color=color, linewidth=2, alpha=0.7) + ax.axvline(x=end - 0.5, color=color, linewidth=2, alpha=0.7) + + # Draw horizontal lines for rows (queries) + ax.axhline(y=start - 0.5, color=color, linewidth=2, alpha=0.7) + ax.axhline(y=end - 0.5, color=color, linewidth=2, alpha=0.7) + + # Add labels at the top + mid_pos = (start + end) / 2 + ax.text(mid_pos, -mask.shape[0] * 0.02, f"{seg_type.upper()}\n({end - start})", + 
ha='center', va='top', fontsize=10, fontweight='bold', color=color) + + # Add labels on the left + ax.text(-mask.shape[1] * 0.02, mid_pos, f"{seg_type.upper()}\n({end - start})", + ha='right', va='center', fontsize=10, fontweight='bold', color=color, rotation=0) + + # Set axis labels + ax.set_xlabel('Key Position (tokens being attended to)', fontsize=12, fontweight='bold') + ax.set_ylabel('Query Position (tokens attending)', fontsize=12, fontweight='bold') + ax.set_title('Attention Mask Pattern\n(White = No Attention, Blue = Attention Allowed)', + fontsize=14, fontweight='bold', pad=20) + + # Create legend for segment types + legend_patches = [] + attention_rules = { + 'image': 'Bidirectional with lang', + 'language': 'Bidirectional with images', + 'subtask': 'Attends to img+lang, causal self', + 'fast': 'Attends to all, causal self' + } + for seg_type, color in segment_colors.items(): + if any(seg[0] == seg_type for seg in att_mask_segments): + rule = attention_rules.get(seg_type, '') + legend_patches.append(mpatches.Patch(color=color, label=f'{seg_type.upper()}: {rule}')) + + ax.legend(handles=legend_patches, loc='upper right', bbox_to_anchor=(1.15, 1.0), + framealpha=0.9, fontsize=9) + + # Adjust layout and save + plt.tight_layout() + + # Ensure the directory exists + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + plt.savefig(save_path, dpi=dpi, bbox_inches='tight') + plt.close() + + logging.info(f"Attention mask visualization saved to: {save_path}") + def sample_noise(self, shape, device): return torch.normal( mean=0.0, @@ -607,10 +807,18 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset return time.to(dtype=torch.float32, device=device) - + def embed_prefix( - self, images, img_masks, tokens, subtask_tokens, masks, subtask_masks - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + self, + images, + img_masks, + tokens, + subtask_tokens, + masks, + subtask_masks, + fast_action_tokens=None, + fast_action_masks=None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: """Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer. 
Args: @@ -619,17 +827,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` tokens: Language instruction tokens subtask_tokens: Subtask tokens to predict (can be None for inference) masks: Attention masks for tokens + fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs + fast_action_masks: Padding masks for FAST action tokens (can be None) Returns: - embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided)] + embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)] pad_masks: Padding masks - att_masks: Attention masks (with causal masking for subtask prediction if subtask_tokens provided) + att_masks: Custom 2D attention mask implementing the required pattern total_T_images: Total number of image tokens + num_subtask_embs: Number of subtask token embeddings + num_fast_embs: Number of FAST action token embeddings """ embs = [] pad_masks = [] - att_masks = [] + att_mask_segments = [] # Store info about each segment for custom mask creation total_T_images = 0 + num_subtask_embs = 0 + num_fast_embs = 0 # Process images for img, img_mask in zip(images, img_masks, strict=True): @@ -642,7 +856,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` embs.append(img_emb) pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) - att_masks += [0] * num_img_embs # Images can attend to all previous tokens + att_mask_segments.append(('image', num_img_embs)) total_T_images += num_img_embs # Process language instruction tokens @@ -656,7 +870,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` pad_masks.append(masks) num_lang_embs = lang_emb.shape[1] - att_masks += [0] * num_lang_embs # Language tokens can attend to all previous tokens (images + tokens) + att_mask_segments.append(('language', num_lang_embs)) # Process subtask tokens if provided (these are predicted, so use causal masking) if subtask_tokens is not None: @@ -672,18 +886,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` pad_masks.append(subtask_masks) num_subtask_embs = subtask_emb.shape[1] - # Causal masking for subtask tokens: each subtask token can attend to images, all instruction tokens, - # and previous subtask tokens - att_masks += [1] * num_subtask_embs # Use 1 for causal attention on subtask tokens + att_mask_segments.append(('subtask', num_subtask_embs)) + # Process FAST action tokens if provided (these are discrete token IDs) + if fast_action_tokens is not None: + def fast_action_embed_func(fast_action_tokens): + fast_emb = self.fast_action_embedding(fast_action_tokens) + fast_emb_dim = fast_emb.shape[-1] + return fast_emb * math.sqrt(fast_emb_dim) + + fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens) + embs.append(fast_action_emb) + + # Use provided mask or create default (all valid) + if fast_action_masks is not None: + fast_pad_mask = fast_action_masks + else: + bsize = fast_action_tokens.shape[0] + num_fast_embs = fast_action_tokens.shape[1] + fast_pad_mask = torch.ones(bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device) + + num_fast_embs = fast_action_tokens.shape[1] + pad_masks.append(fast_pad_mask) + att_mask_segments.append(('fast', num_fast_embs)) embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + # Create custom 2D attention mask + # Attention rules: + # - Images + Language: bidirectional among 
themselves, don't attend to subtask or FAST + # - Subtask: attend to images + language, causal among themselves, don't attend to FAST + # - FAST: attend to images + language + subtask, causal among themselves + att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize) - bsize = pad_masks.shape[0] - att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0]) + # # Optionally visualize the attention mask + # self.visualize_attention_mask( + # att_mask_segments=att_mask_segments, + # att_2d_masks=att_masks, + # save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png", + # batch_idx=0, + # max_display_tokens=512 # Limit display for very long sequences + # ) - return embs, pad_masks, att_masks, total_T_images + return embs, pad_masks, att_masks, total_T_images, num_subtask_embs, num_fast_embs def embed_suffix(self, noisy_actions, timestep): """Embed noisy_actions, timestep to prepare for Expert Gemma processing.""" @@ -732,8 +977,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return embs, pad_masks, att_masks, adarms_cond - # loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions) - def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, noise=None, time=None) -> Tensor: + # loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions, fast_action_tokens, fast_action_masks) + def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, fast_action_tokens=None, fast_action_masks=None, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss. 
Args: @@ -743,7 +988,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` high_level_task_masks: Attention masks for high_level_task subtask_tokens: Subtask tokens to predict (e.g., tokens for "pick up the cup") subtask_masks: Attention masks for subtask_tokens - actions: Ground truth actions + actions: Ground truth actions [B, chunk_size, action_dim] + fast_action_tokens: Discrete action token IDs [B, max_action_tokens] + fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens] noise: Optional noise for flow matching time: Optional time for flow matching """ @@ -757,75 +1004,183 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions - # Embed prefix (images + high_level_task + subtask_tokens) - # Use high_level_task (prompt WITHOUT subtask) + subtask_tokens to predict - prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix( - images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks + # Initialize FAST loss to 0 (will be computed only if FAST tokens are provided) + fast_loss = torch.tensor(0.0, device=actions.device, dtype=actions.dtype) + + # ========== PASS 1: Prefix with FAST tokens for subtask + FAST prediction ========== + # Only run this pass if FAST action tokens are provided + if fast_action_tokens is not None and fast_action_masks is not None: + # Embed prefix (images + high_level_task + subtask_tokens + FAST tokens) + # FAST tokens are provided as discrete token IDs + prefix_with_fast_embs, prefix_with_fast_pad_masks, prefix_with_fast_att_masks, total_T_images, num_subtask_embs, num_fast_embs = self.embed_prefix( + images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks, + fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks + ) + + # Convert embeddings to bfloat16 if needed for the model + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + prefix_with_fast_embs = prefix_with_fast_embs.to(dtype=torch.bfloat16) + + # Prepare attention masks for prefix pass with FAST tokens + position_ids_prefix_with_fast = torch.cumsum(prefix_with_fast_pad_masks, dim=1) - 1 + att_2d_prefix_with_fast_4d = self._prepare_attention_masks_4d(prefix_with_fast_att_masks, dtype=prefix_with_fast_embs.dtype) + + # Forward pass through paligemma for subtask + FAST prediction + (prefix_with_fast_out, _), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_prefix_with_fast_4d, + position_ids=position_ids_prefix_with_fast, + past_key_values=None, + inputs_embeds=[prefix_with_fast_embs, None], # SUFFIX = None + use_cache=False, + adarms_cond=[None, None], + ) + + # LM HEAD → SUBTASK LOGITS + lm_head = self.paligemma_with_expert.paligemma.lm_head + logits = lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, vocab) + + # Extract logits for subtask token prediction + T_high_level_task = high_level_task.size(1) + T_subtask = subtask_tokens.size(1) + start_index = total_T_images + T_high_level_task + end_index = start_index + T_subtask + logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab) + + targets = subtask_tokens # (B, T_subtask) + # Compute cross-entropy loss for subtask + loss_fct = torch.nn.CrossEntropyLoss(reduction='none') + logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) + targets_flat = targets.reshape(-1) + loss_per_token = 
loss_fct(logits_flat, targets_flat) + loss_per_token = loss_per_token.reshape(targets.shape) + masked_loss = loss_per_token * subtask_masks.float() + subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1) + + # Extract outputs for FAST action token prediction and compute auxiliary loss + # FAST outputs start after subtask tokens + # Similar to subtask, we use autoregressive prediction where position i predicts token i+1 + fast_start_index = end_index + fast_end_index = fast_start_index + num_fast_embs + + # Get logits for FAST action tokens using the FAST LM head + fast_logits = self.fast_action_lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, fast_vocab_size) + + # Extract logits for FAST token prediction (autoregressive: position i predicts token i+1) + # - Position (fast_start_index-1) predicts fast_action_tokens[0] + # - Position (fast_start_index) predicts fast_action_tokens[1], etc. + fast_logits_for_pred = fast_logits[:, fast_start_index-1:fast_end_index-1, :] # (B, max_action_tokens, fast_vocab_size) + + # Compute cross-entropy loss for FAST action tokens + fast_targets = fast_action_tokens # (B, max_action_tokens) + loss_fct_fast = torch.nn.CrossEntropyLoss(reduction='none') + fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) # (B*max_action_tokens, fast_vocab_size) + fast_targets_flat = fast_targets.reshape(-1) # (B*max_action_tokens) + + fast_loss_per_token = loss_fct_fast(fast_logits_flat, fast_targets_flat) # (B*max_action_tokens) + fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape) # (B, max_action_tokens) + + # Apply mask and compute mean loss over valid tokens + masked_fast_loss = fast_loss_per_token * fast_action_masks.float() + fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1) + else: + # If no FAST tokens provided, compute subtask loss without FAST tokens + # This is the fallback for backward compatibility + prefix_embs_for_subtask, prefix_pad_masks_for_subtask, prefix_att_masks_for_subtask, total_T_images, _, _ = self.embed_prefix( + images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks, + fast_action_tokens=None, fast_action_masks=None + ) + + # Convert embeddings to bfloat16 if needed for the model + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + prefix_embs_for_subtask = prefix_embs_for_subtask.to(dtype=torch.bfloat16) + + position_ids_prefix = torch.cumsum(prefix_pad_masks_for_subtask, dim=1) - 1 + att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks_for_subtask, dtype=prefix_embs_for_subtask.dtype) + + (prefix_out, _), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_prefix_4d, + position_ids=position_ids_prefix, + past_key_values=None, + inputs_embeds=[prefix_embs_for_subtask, None], + use_cache=False, + adarms_cond=[None, None], + ) + + lm_head = self.paligemma_with_expert.paligemma.lm_head + logits = lm_head(prefix_out) + + T_high_level_task = high_level_task.size(1) + T_subtask = subtask_tokens.size(1) + start_index = total_T_images + T_high_level_task + end_index = start_index + T_subtask + logits_subtask = logits[:, start_index-1:end_index-1, :] + + targets = subtask_tokens + loss_fct = torch.nn.CrossEntropyLoss(reduction='none') + logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) + targets_flat = targets.reshape(-1) + loss_per_token = loss_fct(logits_flat, targets_flat) + loss_per_token = 
loss_per_token.reshape(targets.shape) + masked_loss = loss_per_token * subtask_masks.float() + subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1) + + # ========== PASS 2: Full forward WITHOUT FAST tokens for flow matching ========== + # Embed prefix WITHOUT FAST tokens (images + high_level_task + subtask_tokens) + prefix_embs_no_fast, prefix_pad_masks_no_fast, prefix_att_masks_no_fast, _, _, _ = self.embed_prefix( + images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks, + fast_action_tokens=None, fast_action_masks=None ) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) - # Prepare attention masks for prefix-only pass (for subtask token prediction) - att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) - position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1 - att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype) - - # prefix-only transformer run for subtask token prediction - (prefix_out, _), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_prefix_4d, - position_ids=position_ids_prefix, - past_key_values=None, - inputs_embeds=[prefix_embs, None], # SUFFIX = None - use_cache=False, - adarms_cond=[None, None], - ) - - # LM HEAD → SUBTASK LOGITS - # prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_high_level_task + T_subtask - lm_head = self.paligemma_with_expert.paligemma.lm_head - logits = lm_head(prefix_out) # (B, T_prefix, vocab) - - # Extract logits for subtask token prediction - # In autoregressive modeling, output at position i predicts token at position i+1 - # So we take logits from one position earlier: - # - Position (start_index-1) (last high_level_task token) predicts subtask_tokens[0] - # - Position (start_index) (first subtask token) predicts subtask_tokens[1], etc. 
- T_high_level_task = high_level_task.size(1) - T_subtask = subtask_tokens.size(1) - start_index = total_T_images + T_high_level_task - end_index = start_index + T_subtask - logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab) - - targets = subtask_tokens # (B, T_subtask) - # Compute cross-entropy loss - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - # Reshape for loss computation - logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) # (B*T_subtask, vocab) - targets_flat = targets.reshape(-1) # (B*T_subtask) - - loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_subtask) - loss_per_token = loss_per_token.reshape(targets.shape) # (B, T_subtask) - - # Apply mask and compute mean loss over valid tokens - masked_loss = loss_per_token * subtask_masks.float() - subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1) - # Convert embeddings to bfloat16 if needed for the model if ( self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): suffix_embs = suffix_embs.to(dtype=torch.bfloat16) - prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + prefix_embs_no_fast = prefix_embs_no_fast.to(dtype=torch.bfloat16) - # Concatenate prefix (images + tokens + subtask_tokens) and suffix (actions) masks - pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) - att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) - - # Prepare attention masks for full forward pass (prefix + suffix) - att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + # For the flow matching pass, we need custom attention where: + # - prefix follows the custom pattern (images+lang bidirectional, subtask causal, no cross-attention) + # - suffix attends to all prefix + causal to itself + # We'll construct this by extending prefix_att_masks_no_fast to include suffix + + # prefix_att_masks_no_fast is already a 2D boolean mask [B, prefix_len, prefix_len] + # We need to extend it to [B, prefix_len + suffix_len, prefix_len + suffix_len] + + bsize = prefix_pad_masks_no_fast.shape[0] + prefix_len = prefix_pad_masks_no_fast.shape[1] + suffix_len = suffix_pad_masks.shape[1] + total_len = prefix_len + suffix_len + device = prefix_pad_masks_no_fast.device + + # Create full attention mask + full_att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device) + + # Copy prefix attention pattern + full_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_masks_no_fast + + # Suffix attends to all prefix + full_att_2d_masks[:, prefix_len:, :prefix_len] = True + + # Suffix has causal attention among itself + suffix_causal_mask = torch.tril(torch.ones(suffix_len, suffix_len, dtype=torch.bool, device=device)) + full_att_2d_masks[:, prefix_len:, prefix_len:] = suffix_causal_mask[None, :, :] + + # Apply padding masks + pad_masks = torch.cat([prefix_pad_masks_no_fast, suffix_pad_masks], dim=1) + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + full_att_2d_masks = full_att_2d_masks & pad_2d_masks + position_ids = torch.cumsum(pad_masks, dim=1) - 1 - att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype) + att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=prefix_embs_no_fast.dtype) def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): (_, suffix_out), _ = self.paligemma_with_expert.forward( @@ -836,11 +1191,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` 
use_cache=False, adarms_cond=[None, adarms_cond], ) - # prefix_out to be used for the language head return suffix_out suffix_out = self._apply_checkpoint( - forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond + forward_func, prefix_embs_no_fast, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond ) suffix_out = suffix_out[:, -self.config.chunk_size :] @@ -856,79 +1210,81 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return { "flow_loss": fm_loss, "subtask_loss": subtask_loss, - "loss": 10 * fm_loss.mean() + subtask_loss, + "fast_loss": fast_loss, + "loss": fm_loss.mean() + 0.1 * subtask_loss + 0.05 * fast_loss, # ref: b1k winner } - + @torch.no_grad() def _generate_subtask_tokens( self, images, img_masks, tokens, masks, tokenizer, max_length, device ): - """Generate subtask tokens autoregressively using next token prediction.""" bsize = tokens.shape[0] - - # Get lm_head for token generation lm_head = self.paligemma_with_expert.paligemma.lm_head - # Embed prefix without subtask tokens first - prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix( - images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None + prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _, _ = self.embed_prefix( + images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None, + fast_action_tokens=None, fast_action_masks=None ) - - # Initialize generated tokens list - start with BOS token or first token after instruction - # For PaliGemma, we'll start generation and accumulate tokens + generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device) + # tracking mask: False = still generating, True = finished + finished = torch.zeros(bsize, dtype=torch.bool, device=device) + for t in range(max_length): - # Prepare attention masks for current prefix - att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1 - att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype) - - # Forward pass through model to get logits + att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype) + (prefix_out, _), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_prefix_4d, + attention_mask=att_2d_prefix_4d, position_ids=position_ids_prefix, - past_key_values=None, inputs_embeds=[prefix_embs, None], - use_cache=False, - adarms_cond=[None, None], + # ... ) - # Get logits from the last position - logits = lm_head(prefix_out) # (B, T_prefix, vocab) - next_token_logits = logits[:, -1, :] # (B, vocab) + logits = lm_head(prefix_out) + next_token_logits = logits[:, -1, :] + next_token = torch.argmax(next_token_logits, dim=-1) # (B,) - # Greedy decoding - take the most likely token - next_token = torch.argmax(next_token_logits, dim=-1) # (B,) + # 1. if a row was already finished, force the next token to be PAD (0) + next_token = torch.where(finished, torch.tensor(0, device=device), next_token) - # Store generated token + # 2. store the token generated_tokens[:, t] = next_token - # Check for EOS token - if all batches have generated EOS, stop + # 3. update the finished mask if tokenizer.eos_token_id is not None: - if (next_token == tokenizer.eos_token_id).all(): - break + finished |= (next_token == tokenizer.eos_token_id) - # Embed the generated token and append to prefix - next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1) + # 4. 
break only if everyone is finished + if finished.all(): + break + + next_token_unsqueezed = next_token.unsqueeze(1) def next_token_embed_func(next_token_unsqueezed): next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed) - next_emb_dim = next_emb.shape[-1] - return next_emb * math.sqrt(next_emb_dim) + return next_emb * math.sqrt(next_emb.shape[-1]) next_emb = self._apply_checkpoint(next_token_embed_func, next_token_unsqueezed) - # Append to prefix embeddings + # update embeddings prefix_embs = torch.cat([prefix_embs, next_emb], dim=1) - # Update masks - new token is valid and uses causal attention + # update padding masks prefix_pad_masks = torch.cat([ prefix_pad_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device) ], dim=1) - prefix_att_masks = torch.cat([prefix_att_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1) - + + # update attention masks + old_seq_len = prefix_att_masks.shape[1] + new_seq_len = old_seq_len + 1 + new_att_masks = torch.zeros((bsize, new_seq_len, new_seq_len), dtype=torch.bool, device=device) + new_att_masks[:, :old_seq_len, :old_seq_len] = prefix_att_masks + new_att_masks[:, -1, :] = prefix_pad_masks + prefix_att_masks = new_att_masks + return generated_tokens @torch.no_grad() # see openpi `sample_actions` (slightly adapted) @@ -978,13 +1334,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` subtask_masks = torch.ones_like(generated_subtask_tokens, dtype=torch.bool) # During inference, we don't have subtask_tokens yet, so pass None - prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix( - images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks + # Also no FAST tokens during inference + prefix_embs, prefix_pad_masks, prefix_att_masks, _, _, _ = self.embed_prefix( + images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks, + fast_action_tokens=None, fast_action_masks=None ) - prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + + # Convert embeddings to bfloat16 if needed for the model + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + # prefix_att_masks is already a 2D attention mask from embed_prefix prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype) + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype) self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( @@ -1209,7 +1575,7 @@ class PI05Policy(PreTrainedPolicy): print(f"Remapped {remap_count} state dict keys") # Load the remapped state dict into the model - missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False) if missing_keys: print(f"Missing keys when loading state dict: {len(missing_keys)} keys") @@ -1442,19 +1808,26 @@ class PI05Policy(PreTrainedPolicy): actions = self.prepare_action(batch) # Decode and print ground truth subtask tokens during training - if self.tokenizer is not None and self.training: - bsize = subtask_tokens.shape[0] - for i in range(bsize): - # Remove 
padding tokens (0) and special tokens - valid_tokens = subtask_tokens[i][subtask_masks[i].bool()] - if len(valid_tokens) > 0: - decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True) - # print(f"[Training] Ground truth subtask {i}: {decoded_text}") + # if self.tokenizer is not None and self.training: + # bsize = subtask_tokens.shape[0] + # for i in range(bsize): + # # Remove padding tokens (0) and special tokens + # valid_tokens = subtask_tokens[i][subtask_masks[i].bool()] + # # if len(valid_tokens) > 0: + # # decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True) + # # print(f"[Training] Ground truth subtask {i}: {decoded_text}") + # Get FAST action tokens from batch + fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens) + fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens) # Compute loss (no separate state needed for PI05) # high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:") # subtask_tokens = subtask tokens to predict (e.g., "pick up the cup") - loss_dict = self.model.forward(images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions) + # fast_action_tokens = discrete action token IDs to predict + loss_dict = self.model.forward( + images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, + fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks + ) # Extract the total loss loss = loss_dict["loss"] @@ -1464,6 +1837,7 @@ class PI05Policy(PreTrainedPolicy): "loss": loss.item(), "flow_loss": loss_dict["flow_loss"].mean().item(), "subtask_loss": loss_dict["subtask_loss"].item(), + "fast_loss": loss_dict["fast_loss"].item(), } return loss, detailed_loss_dict diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index 1b1fcf047..65ec6244c 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -33,6 +33,7 @@ from lerobot.processor import ( ProcessorStep, ProcessorStepRegistry, RenameObservationsProcessorStep, + ActionTokenizerProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, ) @@ -158,7 +159,6 @@ def make_pi05_pre_post_processors( Returns: A tuple containing the configured pre-processor and post-processor pipelines. 
""" - # Add remaining processors input_steps: list[ProcessorStep] = [ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one @@ -177,6 +177,9 @@ def make_pi05_pre_post_processors( padding_side="right", padding="max_length", ), + ActionTokenizerProcessorStep( + tokenizer_name="physical-intelligence/fast", + ), DeviceProcessorStep(device=config.device), ] @@ -186,7 +189,7 @@ def make_pi05_pre_post_processors( ), DeviceProcessorStep(device="cpu"), ] - + return ( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index be11ac1af..540f74ef9 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -75,7 +75,7 @@ from .policy_robot_bridge import ( RobotActionToPolicyActionProcessorStep, ) from .rename_processor import RenameObservationsProcessorStep -from .tokenizer_processor import TokenizerProcessorStep +from .tokenizer_processor import TokenizerProcessorStep, ActionTokenizerProcessorStep __all__ = [ "ActionProcessorStep", diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 18b51a9ce..59afddab1 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -15,10 +15,13 @@ # limitations under the License. """ -This script defines a processor for tokenizing natural language instructions from an environment transition. +This script defines processors for tokenizing data from an environment transition. -It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into -token IDs and attention masks, which are then added to the observation dictionary. +It includes: +- TokenizerProcessorStep: Uses a tokenizer from the Hugging Face `transformers` library to convert + task descriptions (text) into token IDs and attention masks, which are then added to the observation dictionary. +- ActionTokenizerProcessorStep: Uses a processor/tokenizer (e.g., the Physical Intelligence "fast" tokenizer) + to tokenize action tensors into discrete token IDs for action modeling. """ from __future__ import annotations @@ -30,6 +33,8 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.utils.constants import ( + ACTION_TOKEN_MASK, + ACTION_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK, OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS, @@ -40,12 +45,13 @@ from lerobot.utils.constants import ( from lerobot.utils.import_utils import _transformers_available from .core import EnvTransition, TransitionKey -from .pipeline import ObservationProcessorStep, ProcessorStepRegistry +from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: - from transformers import AutoTokenizer + from transformers import AutoProcessor, AutoTokenizer else: + AutoProcessor = None AutoTokenizer = None @@ -423,3 +429,233 @@ class TokenizerProcessorStep(ObservationProcessorStep): ) return features + + +@dataclass +@ProcessorStepRegistry.register(name="action_tokenizer_processor") +class ActionTokenizerProcessorStep(ActionProcessorStep): + """ + Processor step to tokenize action data using a fast action tokenizer. 
+
+    This step takes action tensors from an `EnvTransition`, tokenizes them using
+    a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
+    and stores the resulting token IDs and padding mask in the transition's complementary data.
+
+    Requires the `transformers` library to be installed.
+
+    Attributes:
+        tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
+        tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
+        trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
+        max_action_tokens: Maximum number of action tokens per sample; token sequences are truncated or padded to this length.
+        action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
+    """
+
+    tokenizer_name: str | None = None
+    tokenizer: Any | None = None
+    trust_remote_code: bool = True
+    max_action_tokens: int = 32
+    # Internal tokenizer instance (not part of the config)
+    action_tokenizer: Any = field(default=None, init=False, repr=False)
+
+    def __post_init__(self):
+        """
+        Initializes the action tokenizer after the dataclass is created.
+
+        It checks for the availability of the `transformers` library and loads the tokenizer
+        either from a provided object or by name from the Hugging Face Hub.
+
+        Raises:
+            ImportError: If the `transformers` library is not installed.
+            ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
+        """
+        if not _transformers_available:
+            raise ImportError(
+                "The 'transformers' library is not installed. "
+                "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
+            )
+
+        if self.tokenizer is not None:
+            # Use provided tokenizer object directly
+            self.action_tokenizer = self.tokenizer
+        elif self.tokenizer_name is not None:
+            if AutoProcessor is None:
+                raise ImportError("AutoProcessor is not available")
+            self.action_tokenizer = AutoProcessor.from_pretrained(
+                self.tokenizer_name, trust_remote_code=self.trust_remote_code
+            )
+        else:
+            raise ValueError(
+                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
+                "Pass a tokenizer object directly or a tokenizer name to auto-load."
+            )
+
+    def __call__(self, transition: EnvTransition) -> EnvTransition:
+        """
+        Applies action tokenization to the transition.
+
+        This overrides the base class to handle both tokens and mask.
+
+        Args:
+            transition: The input transition with action data.
+
+        Returns:
+            The processed transition with tokenized actions and mask in complementary data.
+        """
+        self._current_transition = transition.copy()
+        new_transition = self._current_transition
+
+        action = new_transition.get(TransitionKey.ACTION)
+        if action is None:
+            raise ValueError("ActionTokenizerProcessorStep requires an action in the transition.")
+
+        # Tokenize and get both tokens and mask
+        tokens, mask = self._tokenize_action(action)
+
+        # Store tokens and mask in complementary data
+        complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
+        if complementary_data is None:
+            complementary_data = {}
+        complementary_data[ACTION_TOKEN_MASK] = mask
+        complementary_data[ACTION_TOKENS] = tokens
+        new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
+        return new_transition
+
+    def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Tokenizes the action tensor and creates a mask.
+
+        Args:
+            action: The input action tensor to tokenize.
+                Shape: (B, action_horizon, action_dim) or (B, action_dim) for batched actions, or (action_dim,) for a single action.
+
+        Returns:
+            A tuple of (tokens, mask) where:
+            - tokens: Tensor of token IDs with shape (B, max_action_tokens)
+            - mask: Boolean mask with shape (B, max_action_tokens), True for real tokens, False for padding
+        """
+        if action is None:
+            raise ValueError("Action cannot be None")
+
+        # Get the device of the input action
+        device = action.device if isinstance(action, torch.Tensor) else None
+
+        # Handle single sample (add batch dimension)
+        single_sample = action.dim() == 1
+        if single_sample:
+            action = action.unsqueeze(0)
+
+        batch_size = action.shape[0]
+
+        # Tokenize the action batch
+        # The fast tokenizer expects action data and returns token IDs
+        tokens_list = []
+        masks_list = []
+
+        for i in range(batch_size):
+            # Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
+            action_cpu = action[i:i+1].cpu()
+            tokens = self.action_tokenizer(action_cpu)
+
+            # Convert to a long tensor if the tokenizer returned a list (or any other non-tensor)
+            if not isinstance(tokens, torch.Tensor):
+                tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
+            else:
+                # Move tokens back to the same device as input action
+                tokens = tokens.to(device=action.device)
+
+            # Flatten to 1D if needed
+            if tokens.dim() > 1:
+                tokens = tokens.flatten()
+
+            # Truncate or pad to max_action_tokens
+            if len(tokens) > self.max_action_tokens:
+                tokens = tokens[:self.max_action_tokens]
+                mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
+            else:
+                mask = torch.cat([
+                    torch.ones(len(tokens), dtype=torch.bool, device=action.device),
+                    torch.zeros(self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device)
+                ])
+                # Pad tokens with zeros
+                tokens = torch.nn.functional.pad(
+                    tokens,
+                    (0, self.max_action_tokens - len(tokens)),
+                    value=0
+                )
+
+            tokens_list.append(tokens)
+            masks_list.append(mask)
+
+        # Stack into batched tensors
+        tokens_batch = torch.stack(tokens_list, dim=0)  # (B, max_action_tokens)
+        masks_batch = torch.stack(masks_list, dim=0)  # (B, max_action_tokens)
+
+        # Remove batch dimension if input was single sample
+        if single_sample:
+            tokens_batch = tokens_batch.squeeze(0)
+            masks_batch = masks_batch.squeeze(0)
+
+        # Move to the same device as the input
+        if device is not None:
+            tokens_batch = tokens_batch.to(device)
+            masks_batch = masks_batch.to(device)
+
+        return tokens_batch, masks_batch
+
+    def action(self, action: torch.Tensor) -> torch.Tensor:
+        """
+        Not used directly, since `__call__` is overridden.
+        Required by the ActionProcessorStep ABC.
+        """
+        tokens, _ = self._tokenize_action(action)
+        return tokens
+
+    def get_config(self) -> dict[str, Any]:
+        """
+        Returns the serializable configuration of the processor.
+
+        Note: The tokenizer object itself is not serialized. If the processor was initialized
+        with a tokenizer name, that name will be included in the config.
+
+        Returns:
+            A dictionary with the processor's configuration parameters.
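+
+            For example (illustrative), a step created with the defaults and
+            `tokenizer_name="physical-intelligence/fast"` would serialize to:
+            `{"trust_remote_code": True, "max_action_tokens": 32, "tokenizer_name": "physical-intelligence/fast"}`.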
+ """ + config = { + "trust_remote_code": self.trust_remote_code, + "max_action_tokens": self.max_action_tokens, + } + + # Only save tokenizer_name if it was used to create the tokenizer + if self.tokenizer_name is not None and self.tokenizer is None: + config["tokenizer_name"] = self.tokenizer_name + + return config + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Updates feature definitions to reflect tokenized actions. + + This updates the policy features dictionary to indicate that the action + has been tokenized into a sequence of token IDs with shape (max_action_tokens,). + + Args: + features: The dictionary of existing policy features. + + Returns: + The updated dictionary of policy features. + """ + # Update the action feature to reflect the tokenized shape + # The action is now a sequence of token IDs + if PipelineFeatureType.ACTION in features: + # Replace the action feature with the tokenized version + features[PipelineFeatureType.ACTION] = { + ACTION_TOKENS: PolicyFeature( + type=FeatureType.SEQUENCE, # Token sequence + shape=(self.max_action_tokens,) + ) + } + + return features diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 02401068c..196580984 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -90,6 +90,7 @@ def update_policy( # Let accelerator handle mixed precision with accelerator.autocast(): loss, output_dict = policy.forward(batch) + action = policy.select_action(batch) breakpoint() # TODO(rcadene): policy.unnormalize_outputs(out_dict) diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index c8e19eb56..f249f5600 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -33,6 +33,8 @@ OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask" OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens" OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask" ACTION = "action" +ACTION_TOKENS = ACTION + ".tokens" +ACTION_TOKEN_MASK = ACTION + ".token_mask" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done"