Compare commits


11 Commits

Author SHA1 Message Date
Jade Choghari 995a46b302 make it work 2025-12-29 17:34:18 +01:00
Jade Choghari 23d4846423 more quick fixes 2025-12-28 07:33:24 +00:00
Jade Choghari 7d897daeb2 add more changes 2025-12-27 21:15:30 +00:00
Jade Choghari 7556c7fd70 more changes 2025-12-27 20:26:23 +00:00
Jade Choghari 4434c863b4 fix training 2025-12-27 10:43:00 +00:00
Jade Choghari 4b40153c32 align fast more 2025-12-26 17:24:39 +00:00
Jade Choghari f0923e5c86 remove brkpt 2025-12-26 06:46:27 +00:00
Jade Choghari 8edd544bbe detokenize action at policy level 2025-12-26 06:45:38 +00:00
Jade Choghari e682ef05f9 make fast work 2025-12-25 20:59:32 +00:00
Jade Choghari 9b5ac4387c add more changes 2025-12-23 13:11:18 +00:00
Jade Choghari 5781754c30 add pifast 2025-12-22 11:36:53 +01:00
20 changed files with 6022 additions and 4126 deletions
BIN
View File
Binary file not shown (new image, 51 KiB).

+18
View File
@@ -0,0 +1,18 @@
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
dataset = LeRobotDataset(repo_id="lerobot/libero")
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
batch = next(iter(dataloader))
print(batch.keys())
breakpoint()
+11 -10
View File
@@ -2,25 +2,26 @@
   "repo_id": "local",
   "vocab_size": 1024,
   "scale": 10.0,
-  "encoded_dims": "0:15",
+  "encoded_dims": "0:7",
   "encoded_dim_ranges": [
     [
       0,
-      15
+      7
     ]
   ],
-  "total_encoded_dims": 15,
+  "total_encoded_dims": 7,
   "delta_dims": null,
   "delta_dim_list": null,
   "use_delta_transform": false,
   "state_key": "observation.state",
-  "action_horizon": 50,
-  "num_training_chunks": 4900,
+  "normalization_mode": "QUANTILES",
+  "action_horizon": 10,
+  "num_training_chunks": 25065,
   "compression_stats": {
-    "compression_ratio": 15.85791309863622,
-    "mean_token_length": 47.295,
-    "p99_token_length": 90.0,
-    "min_token_length": 9.0,
-    "max_token_length": 109.0
+    "compression_ratio": 3.464660463274599,
+    "mean_token_length": 20.204,
+    "p99_token_length": 36.00999999999999,
+    "min_token_length": 5.0,
+    "max_token_length": 38.0
   }
 }
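These stats are internally consistent: compression_ratio is the number of raw action values per chunk (action_horizon × total_encoded_dims) divided by mean_token_length, for both the old and the new tokenizer. A quick check:

# compression_ratio ≈ (action_horizon * total_encoded_dims) / mean_token_length
print((50 * 15) / 47.295)  # ≈ 15.858, matches the old stats
print((10 * 7) / 20.204)   # ≈ 3.465, matches the new stats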
+3 -3
View File
@@ -1,11 +1,11 @@
 {
-  "action_dim": 15,
+  "action_dim": 7,
   "auto_map": {
     "AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
   },
-  "min_token": -71,
+  "min_token": -32,
   "processor_class": "UniversalActionProcessor",
   "scale": 10.0,
-  "time_horizon": 50,
+  "time_horizon": 10,
   "vocab_size": 1024
 }
File diff suppressed because it is too large.
+1 -1
View File
@@ -162,7 +162,7 @@ class LeRobotDatasetMetadata:
         self.info = load_info(self.root)
         check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
         self.tasks = load_tasks(self.root)
-        self.tasks_high_level = load_tasks_high_level(self.root)
+        # self.tasks_high_level = load_tasks_high_level(self.root)
         self.episodes = load_episodes(self.root)
         self.stats = load_stats(self.root)
@@ -40,6 +40,8 @@ class PI05Config(PreTrainedConfig):
     max_action_tokens: int = 32
     fast_vocab_size: int = 2048
+    # FAST-only mode: train with only discrete action token prediction (no flow matching, no subtask)
+    fast_only: bool = False
     # Flow matching parameters: see openpi `PI0Pytorch`
     num_inference_steps: int = 10
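In the training scripts further down, this flag is set from the CLI as `--policy.fast_only=true`. A rough in-code equivalent would look like the sketch below; the import path is an assumption, since this diff only shows the config fields:

# Hypothetical sketch: enable FAST-only training in code rather than via CLI.
from lerobot.policies.pi05.configuration_pi05 import PI05Config  # assumed module path

config = PI05Config(fast_only=True)  # drops flow matching and subtask losses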
+781 -25
View File
@@ -21,8 +21,10 @@ from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from scipy.fftpack import idct
from torch import Tensor, nn
from typing_extensions import Unpack
@@ -536,18 +538,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
# # FAST action token embedding and prediction head
# self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
# self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
# FAST action token embedding and prediction head
self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
# Apply dtype conversion to FAST layers to match model precision
if config.dtype == "bfloat16":
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
elif config.dtype == "float32":
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
from transformers import AutoTokenizer
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224",
trust_remote_code=True,
)
# # Apply dtype conversion to FAST layers to match model precision
# if config.dtype == "bfloat16":
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
# elif config.dtype == "float32":
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
@@ -1213,6 +1220,514 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"fast_loss": fast_loss,
"loss": fm_loss.mean() + 0.1 * subtask_loss + 0.05 * fast_loss, # ref: b1k winner
}
def embed_prefix_fast(
self,
images,
img_masks,
tokens,
masks,
fast_action_tokens=None,
fast_action_masks=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
"""Embed images, language tokens, and FAST action tokens for FAST-only mode.
This is a simplified version of embed_prefix without subtask tokens.
Attention pattern:
- Images + Language: bidirectional among themselves
- FAST: attend to images + language, causal among themselves
Args:
images: List of image tensors
img_masks: List of image masks
tokens: Language instruction tokens
masks: Attention masks for tokens
fast_action_tokens: FAST action tokens (discrete token IDs)
fast_action_masks: Padding masks for FAST action tokens
Returns:
embs: Concatenated embeddings [images, tokens, fast_action_tokens]
pad_masks: Padding masks
att_masks: 2D attention mask
total_T_images: Total number of image tokens
num_fast_embs: Number of FAST action token embeddings
"""
embs = []
pad_masks = []
att_mask_segments = []
total_T_images = 0
num_fast_embs = 0
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
def image_embed_func(img):
return self.paligemma_with_expert.embed_image(img)
img_emb = self._apply_checkpoint(image_embed_func, img)
bsize, num_img_embs = img_emb.shape[:2]
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
att_mask_segments.append(('image', num_img_embs))
total_T_images += num_img_embs
# Process language instruction tokens
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
pad_masks.append(masks)
num_lang_embs = lang_emb.shape[1]
att_mask_segments.append(('language', num_lang_embs))
# Process FAST action tokens (discrete token IDs)
if fast_action_tokens is not None:
def fast_action_embed_func(fast_action_tokens):
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
fast_emb_dim = fast_emb.shape[-1]
return fast_emb * math.sqrt(fast_emb_dim)
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
embs.append(fast_action_emb)
if fast_action_masks is not None:
fast_pad_mask = fast_action_masks
else:
bsize = fast_action_tokens.shape[0]
num_fast_embs = fast_action_tokens.shape[1]
fast_pad_mask = torch.ones(bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device)
num_fast_embs = fast_action_tokens.shape[1]
pad_masks.append(fast_pad_mask)
att_mask_segments.append(('fast', num_fast_embs))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
# Create custom 2D attention mask for FAST-only mode:
# - Images + Language: bidirectional among themselves
# - FAST: attend to images + language, causal among themselves
att_masks = self._create_custom_attention_mask_fast(att_mask_segments, pad_masks, bsize)
return embs, pad_masks, att_masks, total_T_images, num_fast_embs
def _create_custom_attention_mask_fast(self, att_mask_segments, pad_masks, bsize):
"""Create custom 2D attention mask for FAST-only mode.
Attention rules:
- Images + Language: bidirectional among themselves
- FAST: attend to images + language, causal among themselves
"""
total_len = sum(length for _, length in att_mask_segments)
device = pad_masks.device
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
positions = []
current_pos = 0
for seg_type, seg_len in att_mask_segments:
positions.append((seg_type, current_pos, current_pos + seg_len))
current_pos += seg_len
for i, (query_type, query_start, query_end) in enumerate(positions):
for j, (key_type, key_start, key_end) in enumerate(positions):
# Images and Language can attend to each other bidirectionally
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# FAST tokens attend to images + language
elif query_type == 'fast' and key_type in ['image', 'language']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# FAST tokens attend causally to themselves
elif query_type == 'fast' and key_type == 'fast':
fast_len = query_end - query_start
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
# Apply padding masks
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
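# Minimal standalone sketch (plain PyTorch only) of the attention pattern
# built above: the prefix (images + language) is bidirectional, while FAST
# tokens see the whole prefix and are causal among themselves.
import torch

segments = [("image", 2), ("language", 2), ("fast", 3)]  # toy segment lengths
total = sum(n for _, n in segments)
mask = torch.zeros(total, total, dtype=torch.bool)
pos, spans = 0, []
for kind, n in segments:
    spans.append((kind, pos, pos + n))
    pos += n
for q_kind, qs, qe in spans:
    for k_kind, ks, ke in spans:
        if q_kind in ("image", "language") and k_kind in ("image", "language"):
            mask[qs:qe, ks:ke] = True  # bidirectional prefix
        elif q_kind == "fast" and k_kind in ("image", "language"):
            mask[qs:qe, ks:ke] = True  # FAST attends to the full prefix
        elif q_kind == "fast" and k_kind == "fast":
            mask[qs:qe, ks:ke] = torch.tril(torch.ones(qe - qs, ke - ks, dtype=torch.bool))
print(mask.int())  # rows = queries, cols = keys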
def forward_fast_only(
self,
images,
img_masks,
tokens,
masks,
fast_action_tokens,
fast_action_masks,
) -> dict:
"""Forward pass for FAST-only mode (no flow matching, no subtask).
This implements the Pi0FAST training objective: predict next action token
using cross-entropy loss.
Args:
images: List of image tensors
img_masks: List of image masks
tokens: Language instruction tokens
masks: Attention masks for tokens
fast_action_tokens: Discrete action token IDs [B, max_action_tokens]
fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens]
Returns:
Dictionary with 'fast_loss' and 'loss' keys
"""
if fast_action_tokens is None or fast_action_masks is None:
raise ValueError("fast_action_tokens and fast_action_masks are required for FAST-only mode")
# Embed prefix with FAST tokens
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, num_fast_embs = self.embed_prefix_fast(
images, img_masks, tokens, masks,
fast_action_tokens=fast_action_tokens,
fast_action_masks=fast_action_masks
)
# Convert embeddings to bfloat16 if needed
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
# For next-token prediction, we input tokens [0:T-1] to predict tokens [1:T]
# So we remove the last token from input
# input_embs = prefix_embs[:, :-1]
# input_pad_masks = prefix_pad_masks[:, :-1]
# input_att_masks = prefix_att_masks[:, :-1, :-1]
input_embs = prefix_embs
input_pad_masks = prefix_pad_masks
input_att_masks = prefix_att_masks
position_ids = torch.cumsum(input_pad_masks, dim=1) - 1
att_2d_4d = self._prepare_attention_masks_4d(input_att_masks, dtype=input_embs.dtype)
# Forward pass through paligemma (language model only, no action expert)
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[input_embs, None], # No suffix/action expert
use_cache=False,
adarms_cond=[None, None],
)
# Get logits for FAST action tokens using the FAST LM head
# We only compute logits for the positions that predict FAST tokens
lm_head = self.paligemma_with_expert.paligemma.lm_head
# The FAST tokens start at position (total_T_images + num_lang_tokens)
# For next-token prediction:
# - Position (fast_start - 1) in input predicts fast_action_tokens[0]
# - Position (fast_start) in input predicts fast_action_tokens[1], etc.
# Targets are the FAST action tokens
fast_targets = fast_action_tokens # (B, num_fast_embs)
T_lang = masks.shape[1]
fast_start = total_T_images + T_lang
# Extract logits for FAST token prediction
# Input positions [fast_start-1 : fast_start-1+num_fast_embs] predict FAST tokens
# fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, hidden_dim)
fast_hidden = prefix_out[:, -fast_targets.shape[1]:, :]
fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size)
# Shift left for next-step prediction and shift target
# logits[:, i] predicts targets[:, i+1]
fast_logits_for_pred = fast_logits_for_pred[:, :-1, :] # Shift logits left
fast_targets = fast_targets[:, 1:] # Shift targets right
fast_action_masks = fast_action_masks[:, 1:] # Shift masks to match targets
# from transformers import AutoTokenizer
# self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
# "google/paligemma-3b-pt-224",
# trust_remote_code=True,
# add_eos_token=True,
# add_bos_token=False
# )
# # remove
# decoded_tokens = [
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
# for seq in fast_targets
# ]
# corrected_tokens = [
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
# for seq in fast_logits_for_pred.argmax(dim=-1)
# ]
# breakpoint()
# Compute cross-entropy loss
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
fast_targets_flat = fast_targets.reshape(-1)
fast_loss_per_token = loss_fct(fast_logits_flat, fast_targets_flat)
fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape)
# Apply mask and compute mean loss
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
return {
"fast_loss": fast_loss,
"loss": fast_loss,
}
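# Minimal sketch of the shift-by-one loss used above, on toy shapes: position
# i predicts token i+1, and padded positions are masked out of the mean.
import torch

B, T, V = 2, 5, 11
logits = torch.randn(B, T, V)            # LM-head outputs at each position
targets = torch.randint(0, V, (B, T))    # FAST action token ids
mask = torch.ones(B, T, dtype=torch.bool)

logits, targets, mask = logits[:, :-1], targets[:, 1:], mask[:, 1:]
per_token = torch.nn.functional.cross_entropy(
    logits.reshape(-1, V), targets.reshape(-1), reduction="none"
).reshape(targets.shape)
loss = (per_token * mask).sum() / mask.sum().clamp(min=1)
print(loss)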
@torch.no_grad()
def sample_actions_fast(
self,
images,
img_masks,
tokens,
masks,
max_decoding_steps=None,
temperature=0.0,
) -> torch.Tensor:
"""
Inefficient but safe autoregressive decoding for FAST tokens.
Matches the pattern of _generate_subtask_tokens.
"""
if max_decoding_steps is None:
max_decoding_steps = self.config.max_action_tokens
bsize = tokens.shape[0]
device = tokens.device
lm_head = self.paligemma_with_expert.paligemma.lm_head
# add bos token after tokens
bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
tokens = torch.cat([tokens, bos_token], dim=1)
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
# 1. Initial Embedding (Matches Training Prefix)
# prefix_embs will include [Images, Language Prompt]
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
images, img_masks, tokens, masks,
fast_action_tokens=None,
fast_action_masks=None
)
if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
# 2. Decoding Loop (Re-computes full sequence every step)
for t in range(max_decoding_steps):
# Always re-calculate position IDs from the current pad mask (matches training)
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
# Full forward pass (No KV Cache)
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
adarms_cond=[None, None],
)
# Predict next token from the very last sequence position
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
if temperature > 0:
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
generated_action_tokens[:, t] = next_token.squeeze(-1)
# 3. Update Sequence for next iteration (unless it's the last step)
if t < max_decoding_steps - 1:
# Embed the newly generated token
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
# Append to embeddings
prefix_embs = torch.cat([prefix_embs, next_token_emb], dim=1)
# Update padding mask (New token is always valid/1)
prefix_pad_masks = torch.cat([
prefix_pad_masks,
torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1)
# Update 2D attention mask: Grow the matrix
old_len = prefix_att_masks.shape[1]
new_len = old_len + 1
new_att_masks = torch.zeros((bsize, new_len, new_len), dtype=torch.bool, device=device)
new_att_masks[:, :old_len, :old_len] = prefix_att_masks
# New token attends to all non-padding tokens in the updated sequence
new_att_masks[:, -1, :] = prefix_pad_masks
prefix_att_masks = new_att_masks
return generated_action_tokens
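# Sketch of the two sampling branches in the loop above: temperature 0 falls
# back to greedy argmax, otherwise sample from the temperature-scaled softmax.
import torch

logits = torch.randn(2, 100)                        # (batch, vocab)
greedy = torch.argmax(logits, dim=-1, keepdim=True)  # (2, 1)
probs = torch.softmax(logits / 0.7, dim=-1)          # temperature 0.7
sampled = torch.multinomial(probs, num_samples=1)    # (2, 1)
print(greedy.shape, sampled.shape)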
@torch.no_grad()
def sample_actions_fast_kv_cache(
self,
images,
img_masks,
tokens,
masks,
max_decoding_steps=None,
temperature=0.0,
) -> torch.Tensor:
"""
Efficient autoregressive decoding for FAST tokens using KV-caching.
Only computes the prefix once, then incrementally generates tokens.
"""
if max_decoding_steps is None:
max_decoding_steps = self.config.max_action_tokens
bsize = tokens.shape[0]
device = tokens.device
lm_head = self.paligemma_with_expert.paligemma.lm_head
bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
tokens = torch.cat([tokens, bos_token], dim=1)
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
# 1. Initial Embedding (Matches Training Prefix)
# prefix_embs will include [Images, Language Prompt]
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
images, img_masks, tokens, masks,
fast_action_tokens=None,
fast_action_masks=None
)
if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
# 2. Initial forward pass to populate KV cache
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
# First forward pass with full prefix (caching enabled)
(prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
attention_mask=att_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=True,
adarms_cond=[None, None],
)
# # Get BOS token and add it as the first token in action sequence
# bos_id = self._paligemma_tokenizer.bos_token_id
# bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)
# # Embed BOS token
# bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
# bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
# if prefix_embs.dtype == torch.bfloat16:
# bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)
# Track current sequence length for position IDs and maintain the padding mask
current_seq_len = prefix_embs.shape[1]
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
current_pad_mask = prefix_pad_masks.clone() # (B, seq_len)
# # Update padding mask for BOS token: it's always valid
# current_pad_mask = torch.cat([
# current_pad_mask,
# torch.ones((bsize, 1), dtype=torch.bool, device=device)
# ], dim=1) # (B, seq_len+1)
# # Position ID for BOS token (continues from where prefix ended)
# bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
# # Attention mask for BOS token: attends to all valid prefix positions
# bos_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
# bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)
# # Forward pass with BOS token (reusing cached KVs from prefix)
# (bos_out, _), past_key_values = self.paligemma_with_expert.forward(
# attention_mask=bos_att_4d,
# position_ids=bos_position_id,
# past_key_values=past_key_values,
# inputs_embeds=[bos_token_emb, None],
# use_cache=True,
# adarms_cond=[None, None],
# )
# # Update sequence length to account for BOS token
# current_seq_len += 1
# Predict first action token from BOS token output
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
if temperature > 0:
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
generated_action_tokens[:, 0] = next_token.squeeze(-1)
# 3. Incremental Decoding Loop (using KV cache)
for t in range(1, max_decoding_steps):
# Embed the newly generated token
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
# Update padding mask: new generated token is always valid
current_pad_mask = torch.cat([
current_pad_mask,
torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1) # (B, seq_len+1)
# Position ID for the new token (continues from where we left off)
new_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
# For KV-cache: attention mask for the new token should only attend to valid positions
# Shape: (B, 1, past_len+1) where the new token attends to valid prefix + all generated tokens
new_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
att_4d_incremental = self._prepare_attention_masks_4d(new_att_mask_2d, dtype=next_token_emb.dtype)
# Forward pass with only the new token embedding (reusing cached KVs)
(new_out, _), past_key_values = self.paligemma_with_expert.forward(
attention_mask=att_4d_incremental,
position_ids=new_position_id,
past_key_values=past_key_values,
inputs_embeds=[next_token_emb, None],
use_cache=True,
adarms_cond=[None, None],
)
# Predict next token
last_logits = lm_head(new_out[:, -1:, :]) # (B, 1, vocab_size)
if temperature > 0:
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
generated_action_tokens[:, t] = next_token.squeeze(-1)
# Update sequence length
current_seq_len += 1
return generated_action_tokens
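# Back-of-envelope for why the KV-cached variant is cheaper: the naive loop
# re-encodes the whole growing sequence each step, while the cached loop
# encodes the prefix once and then one query token per step (toy numbers).
prefix_len, steps = 300, 256
naive = sum(prefix_len + t for t in range(steps))  # query tokens processed in total
cached = prefix_len + steps
print(naive, cached)  # 109440 vs. 556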
@torch.no_grad()
def _generate_subtask_tokens(
@@ -1480,6 +1995,34 @@ class PI05Policy(PreTrainedPolicy):
except Exception as e:
logging.warning(f"Could not load tokenizer for subtask decoding: {e}")
self.tokenizer = None
# Load FAST tokenizer for action detokenization (only in fast_only mode)
self.action_tokenizer = None
self._paligemma_tokenizer = None
self._fast_skip_tokens = 128
if config.fast_only:
try:
from transformers import AutoProcessor, AutoTokenizer
# Load FAST tokenizer
self.action_tokenizer = AutoProcessor.from_pretrained(
"jadechoghari/fast-libero-tokenizer-mean-std",
trust_remote_code=True
)
# Load PaliGemma tokenizer for token conversion
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224",
trust_remote_code=True,
add_eos_token=True,
add_bos_token=False
)
logging.info("Loaded FAST tokenizer for action detokenization")
except Exception as e:
logging.warning(f"Could not load FAST tokenizer for action detokenization: {e}")
logging.warning("Action tokens will be returned without detokenization")
self.reset()
@@ -1556,7 +2099,6 @@ class PI05Policy(PreTrainedPolicy):
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
remapped_state_dict = {}
remap_count = 0
@@ -1655,6 +2197,9 @@ class PI05Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if key == "model.paligemma_with_expert.paligemma.lm_head.weight" or key == "paligemma_with_expert.paligemma.lm_head.weight":
fixed_state_dict["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict
@@ -1755,6 +2300,173 @@ class PI05Policy(PreTrainedPolicy):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
Args:
tokens: PaliGemma token IDs
Returns:
Action token IDs
"""
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
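# The remap t -> vocab_size - 1 - skip - t is an involution, so the same
# formula converts action tokens to PaliGemma ids and back (vocab size below
# is an assumption for illustration, not read from the tokenizer).
vocab_size, skip = 257152, 128

def remap(t):
    return vocab_size - 1 - skip - t

assert remap(remap(42)) == 42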
def decode_actions_with_fast(
self,
token_ids: list[int],
time_horizon: int,
action_dim: int,
relaxed_decoding: bool = True
) -> np.ndarray:
"""
Decodes action token IDs back to continuous action values using the FAST tokenizer.
Args:
token_ids: List of token IDs to decode.
time_horizon: The number of timesteps for actions.
action_dim: The dimensionality of each action.
relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
Returns:
A numpy array representing the decoded actions.
"""
decoded_actions = []
for token in token_ids:
try:
decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token
if relaxed_decoding:
# expected sequence length
expected_seq_len = time_horizon * action_dim
diff = expected_seq_len - decoded_dct_coeff.shape[0]
# apply truncation if too long
if diff < 0:
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # truncate on the right
# apply padding if too short
elif diff > 0:
decoded_dct_coeff = np.pad(
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
)
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
assert decoded_dct_coeff.shape == (
time_horizon,
action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
)
except Exception as e:
logging.warning(f"Error decoding tokens: {e}")
logging.warning(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((time_horizon, action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
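# Round-trip sketch of the final decoding step above: dct/idct with
# norm="ortho" are exact inverses, so scaling then unscaling recovers actions.
import numpy as np
from scipy.fftpack import dct, idct

scale = 10.0
actions = np.random.randn(10, 7)                    # (time_horizon, action_dim)
coeff = dct(actions, axis=0, norm="ortho") * scale  # FAST quantizes/BPE-encodes these
assert np.allclose(actions, idct(coeff / scale, axis=0, norm="ortho"))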
def detokenize_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
"""
Detokenizes action tokens back to continuous actions.
This method converts predicted action tokens from the model back to continuous action values
using the FAST tokenizer. It handles the conversion from PaliGemma token space to action token
space, then decodes the action tokens to continuous values using DCT decoding.
Args:
tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)
action_horizon: The number of timesteps for actions.
action_dim: The dimensionality of each action.
Returns:
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
"""
from transformers import AutoTokenizer
self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True)
if self.action_tokenizer is None or self._paligemma_tokenizer is None:
raise ValueError(
"Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully."
)
# Handle single sample (add batch dimension)
single_sample = tokens.dim() == 1
if single_sample:
tokens = tokens.unsqueeze(0)
# Convert token IDs to token strings
decoded_tokens = [
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
for seq in tokens
]
# Get the token sequence for "Action: " to remove it
action_prefix_ids = self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False)
action_prefix_tokens = self._paligemma_tokenizer.convert_ids_to_tokens(action_prefix_ids)
action_prefix_len = len(action_prefix_tokens)
# Clean tokens by removing everything after the first "|" (end-of-action marker)
# and removing all occurrences of "Action: " token sequence
# assert that beginning contain "Action: "
for token_seq in decoded_tokens:
assert (
len(token_seq) >= 2
and token_seq[0] == "Action"
and token_seq[1] == ":"
), f"Token sequence does not start with ['Action', ':']: {token_seq}"
cleaned_tokens = []
for token_seq in decoded_tokens:
# Remove everything after "|"
if "|" in token_seq:
token_seq = token_seq[:token_seq.index("|")]
# Remove all occurrences of "Action: " token sequence
i = 0
while i <= len(token_seq) - action_prefix_len:
if token_seq[i:i+action_prefix_len] == action_prefix_tokens:
# Found a match, remove it
token_seq = token_seq[:i] + token_seq[i+action_prefix_len:]
else:
i += 1
cleaned_tokens.append(token_seq)
# Convert token strings back to IDs
raw_action_tokens = [
torch.tensor(
self._paligemma_tokenizer.convert_tokens_to_ids(token_seq),
dtype=torch.long,
device=tokens.device,
)
for token_seq in cleaned_tokens
]
# Convert PaliGemma tokens to action tokens
action_tokens = [
self._paligemma_tokens_to_act_tokens(raw_action_token)
for raw_action_token in raw_action_tokens
]
# Decode action tokens to continuous actions
actions = self.decode_actions_with_fast(
action_tokens,
time_horizon=action_horizon,
action_dim=action_dim
)
# Convert to tensor and return
actions_tensor = torch.tensor(actions, dtype=torch.float32, device=tokens.device)
# Remove batch dimension if input was single sample
if single_sample:
actions_tensor = actions_tensor.squeeze(0)
return actions_tensor
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -1780,6 +2492,35 @@ class PI05Policy(PreTrainedPolicy):
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
# FAST-only mode: use autoregressive decoding
if self.config.fast_only:
tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Get optional parameters
temperature = kwargs.get("temperature", 0.0)
max_decoding_steps = 256
# Sample action tokens autoregressively
action_tokens = self.model.sample_actions_fast(
images, img_masks, tokens, masks,
max_decoding_steps=max_decoding_steps,
temperature=temperature,
)
# Detokenize action tokens to continuous actions
action_horizon = self.config.n_action_steps
action_dim = 7
continuous_actions = self.detokenize_actions(
action_tokens,
action_horizon=action_horizon,
action_dim=action_dim
)
return continuous_actions
# Full mode: use flow matching with optional subtask generation
# Use high_level_task tokens (WITHOUT subtask) for inference - we'll generate the subtask
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"]
@@ -1802,24 +2543,39 @@ class PI05Policy(PreTrainedPolicy):
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
# Get FAST action tokens from batch
fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens)
fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens)
# FAST-only mode: only use discrete action token prediction
if self.config.fast_only:
# Use full language tokens (no separation into high_level_task and subtask)
tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
if fast_action_tokens is None or fast_action_masks is None:
raise ValueError("FAST-only mode requires action.tokens and action.token_mask in the batch")
loss_dict = self.model.forward_fast_only(
images, img_masks, tokens, masks,
fast_action_tokens=fast_action_tokens,
fast_action_masks=fast_action_masks
)
loss = loss_dict["loss"]
detailed_loss_dict = {
"loss": loss.item(),
"fast_loss": loss_dict["fast_loss"].item(),
}
return loss, detailed_loss_dict
# Full mode: flow matching + subtask + FAST
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"]
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
# Decode and print ground truth subtask tokens during training
# if self.tokenizer is not None and self.training:
# bsize = subtask_tokens.shape[0]
# for i in range(bsize):
# # Remove padding tokens (0) and special tokens
# valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
# # if len(valid_tokens) > 0:
# # decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
# # print(f"[Training] Ground truth subtask {i}: {decoded_text}")
# Get FAST action tokens from batch
fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens)
fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens)
# Compute loss (no separate state needed for PI05)
# high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
# subtask_tokens = subtask tokens to predict (e.g., "pick up the cup")
+2 -2
View File
@@ -101,8 +101,8 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
                 cleaned_high_level_task = cleaned_high_level_tasks[i]
                 full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}"
             else:
-                full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
+                full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"  # "Action: " prefix is now added by the action tokenizer
             low_level_prompts.append(full_prompt)
         transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = low_level_prompts
+6 -5
View File
@@ -2,7 +2,7 @@ export CUDA_LAUNCH_BLOCKING=1
 lerobot-train \
   --dataset.repo_id=local \
   --dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
-  --output_dir=/fsx/jade_choghari/outputs/pi0_fast_fruit1 \
+  --output_dir=/fsx/jade_choghari/outputs/pi0_fast_fruit2 \
   --job_name=pi0_training \
   --policy.repo_id=jade_choghari/pi0-base1 \
   --policy.path=lerobot/pi05_base \
@@ -14,9 +14,10 @@ lerobot-train \
     "observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
     "observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
   }' \
-  --batch_size=4 \
+  --batch_size=16 \
   --policy.device=cuda \
-  --wandb.enable=true \
-  --wandb.disable_artifact=true \
-  --wandb.project=pi05hi-training \
+  --policy.fast_only=true \
+  # --wandb.enable=true \
+  # --wandb.disable_artifact=true \
+  # --wandb.project=pi05hi-training \
 # /fsx/jade_choghari/.cache/huggingface/lerobot/jadechoghari/collect-data
+3 -8
View File
@@ -1,18 +1,13 @@
 rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
 lerobot-train \
   --dataset.repo_id=local\
-  --dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
+  --dataset.root=/fsx/jade_choghari/data/libero \
   --output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
   --job_name=pi0_multi_training \
   --policy.repo_id=jadechoghari/pi0-base1 \
-  --policy.path=lerobot/pi05_base \
+  --policy.path=/fsx/jade_choghari/outputs/libero_training_fast_6/checkpoints/last/pretrained_model/ \
   --policy.dtype=bfloat16 \
   --steps=50000 \
   --save_freq=5000 \
-  --rename_map='{
-    "observation.images.base": "observation.images.base_0_rgb",
-    "observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
-    "observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
-  }' \
-  --batch_size=32 \
+  --batch_size=4 \
   --policy.device=cuda \
+9 -6
View File
@@ -1,9 +1,12 @@
@@ -1,9 +1,12 @@
 python src/lerobot/policies/pi05/train_fast_tokenizer.py \
   --repo_id "local" \
-  --root "/fsx/jade_choghari/outputs/collect-data-pgen" \
-  --action_horizon 16 \
-  --encoded_dims "0:15" \
-  --action_horizon 50 \
+  --root /fsx/jade_choghari/data/libero \
+  --action_horizon 10 \
+  --encoded_dims "0:7" \
   --vocab_size 1024 \
   --scale 10.0 \
-  --output_dir "/fsx/jade_choghari/outputs/fast_tokenizer"
+  --output_dir "/fsx/jade_choghari/outputs/fast_tokenizer" \
+  --push_to_hub \
+  --hub_repo_id jadechoghari/fast-libero-tokenizer-quantiles \
+  --normalization_mode QUANTILES \
 # python train_fast_tokenizer.py --repo_id my_dataset
+141 -18
View File
@@ -15,6 +15,8 @@ from pathlib import Path
from transformers import AutoProcessor
import torch
from huggingface_hub import HfApi
from lerobot.configs.types import NormalizationMode
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -39,6 +41,64 @@ def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: li
return delta_actions
def apply_normalization(
data: np.ndarray,
stats: dict[str, np.ndarray],
mode: NormalizationMode,
eps: float = 1e-8,
) -> np.ndarray:
"""Apply normalization to data based on the specified mode.
Args:
data: Data to normalize [N, H, D] or [D]
stats: Dictionary of statistics (mean, std, min, max, q01, q99, q10, q90)
mode: Normalization mode to apply
eps: Small epsilon for numerical stability
Returns:
Normalized data with the same shape as input
"""
if mode == NormalizationMode.IDENTITY:
return data
if mode == NormalizationMode.MEAN_STD:
mean = stats.get("mean")
std = stats.get("std")
if mean is None or std is None:
raise ValueError("MEAN_STD mode requires 'mean' and 'std' in stats")
return (data - mean) / np.maximum(std, eps)
if mode == NormalizationMode.MIN_MAX:
min_val = stats.get("min")
max_val = stats.get("max")
if min_val is None or max_val is None:
raise ValueError("MIN_MAX mode requires 'min' and 'max' in stats")
denom = np.maximum(max_val - min_val, eps)
return 2.0 * (data - min_val) / denom - 1.0
if mode == NormalizationMode.QUANTILES:
q01 = stats.get("q01")
q99 = stats.get("q99")
if q01 is None or q99 is None:
raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats")
denom = np.maximum(q99 - q01, eps)
# Clip to quantile range then normalize to [-1, 1]
clipped = np.clip(data, q01, q99)
return 2.0 * (clipped - q01) / denom - 1.0
if mode == NormalizationMode.QUANTILE10:
q10 = stats.get("q10")
q90 = stats.get("q90")
if q10 is None or q90 is None:
raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats")
denom = np.maximum(q90 - q10, eps)
# Clip to quantile range then normalize to [-1, 1]
clipped = np.clip(data, q10, q90)
return 2.0 * (clipped - q10) / denom - 1.0
raise ValueError(f"Unsupported normalization mode: {mode}")
def process_episode(args):
"""Process single episode and return action chunks."""
dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args
@@ -237,9 +297,13 @@ def main(
delta_dims: str | None = None,
use_delta_transform: bool = False,
state_key: str = "observation.state",
normalization_mode: str = "QUANTILES",
vocab_size: int = 1024,
scale: float = 10.0,
output_dir: str | None = None,
push_to_hub: bool = False,
hub_repo_id: str | None = None,
hub_private: bool = False,
):
"""
Train FAST tokenizer for action encoding.
@@ -254,15 +318,29 @@ def main(
delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
state_key: Dataset key for state observations (default: "observation.state")
normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
vocab_size: FAST vocabulary size (BPE vocab size)
scale: DCT scaling factor (default: 10.0)
output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
push_to_hub: Whether to push the tokenizer to Hugging Face Hub
hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
hub_private: Whether to create a private repository on the Hub
"""
# Load dataset
print(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id=repo_id, root=root)
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
# Parse normalization mode
try:
norm_mode = NormalizationMode(normalization_mode)
except ValueError:
raise ValueError(
f"Invalid normalization_mode: {normalization_mode}. "
f"Must be one of: {', '.join([m.value for m in NormalizationMode])}"
)
print(f"Normalization mode: {norm_mode.value}")
# Parse encoded dimensions
encoded_dim_ranges = []
for range_str in encoded_dims.split(','):
@@ -317,13 +395,12 @@ def main(
encoded_chunks = np.concatenate(encoded_chunks, axis=-1) # [N, H, D_encoded]
print(f"Extracted {encoded_chunks.shape[-1]} encoded dimensions")
-    # Apply normalization to encoded dimensions only
-    # NOTE: For FAST, we ALWAYS use QUANTILE normalization (no per-timestamp)
-    # This clips outliers and provides consistent [-1, 1] range for DCT compression
+    # Apply normalization to encoded dimensions
print(f"\nBefore normalization - overall stats:")
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
# Get normalization stats from dataset
norm_stats = dataset.meta.stats
if norm_stats is not None and "action" in norm_stats:
action_stats = norm_stats["action"]
@@ -334,19 +411,31 @@
             encoded_dim_indices.extend(range(start, end))
         encoded_dim_indices = np.array(encoded_dim_indices)
-        # Use QUANTILE normalization: clip to [q01, q99] and map to [-1, 1]
-        if "q01" in action_stats and "q99" in action_stats:
-            q01 = np.array(action_stats["q01"])[encoded_dim_indices]  # [D_encoded]
-            q99 = np.array(action_stats["q99"])[encoded_dim_indices]  # [D_encoded]
+        # Extract stats for encoded dimensions only
+        encoded_stats = {}
+        for stat_name, stat_values in action_stats.items():
+            if isinstance(stat_values, (list, np.ndarray)):
+                stat_array = np.array(stat_values)
+                if len(stat_array) > max(encoded_dim_indices):
+                    encoded_stats[stat_name] = stat_array[encoded_dim_indices]
+        if encoded_stats:
+            print(f"\nNormalization stats for encoded dimensions (mode: {norm_mode.value}):")
+            for stat_name, stat_values in encoded_stats.items():
+                print(f"  {stat_name}: shape={stat_values.shape}, "
+                      f"range=[{np.min(stat_values):.4f}, {np.max(stat_values):.4f}]")
-            print(f"\nNormalization stats (q01, q99) for encoded dimensions:")
-            for i, dim_idx in enumerate(encoded_dim_indices):
-                print(f"  Orig dim {dim_idx}: q01={q01[i]:7.4f}, q99={q99[i]:7.4f}, range={q99[i]-q01[i]:7.4f}")
-            # Clip to quantile range and normalize to [-1, 1]
-            encoded_chunks = np.clip(encoded_chunks, q01, q99)
-            encoded_chunks = 2.0 * (encoded_chunks - q01) / np.maximum(q99 - q01, 1e-6) - 1.0
-            print(f"\nApplied quantile normalization [q01, q99] → [-1, 1]")
+            # Apply normalization based on mode
+            try:
+                encoded_chunks = apply_normalization(
+                    encoded_chunks,
+                    encoded_stats,
+                    norm_mode,
+                    eps=1e-8
+                )
+                print(f"\nApplied {norm_mode.value} normalization")
+            except ValueError as e:
+                print(f"Warning: {e}. Using raw actions without normalization.")
print(f"\nAfter normalization - overall stats:")
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
@@ -358,9 +447,9 @@
                 print(f"  Dim {d}: min={np.min(dim_data):7.4f}, max={np.max(dim_data):7.4f}, "
                       f"mean={np.mean(dim_data):7.4f}, std={np.std(dim_data):7.4f}")
         else:
-            print("Warning: q01/q99 stats not found, using raw actions")
+            print("Warning: Could not extract stats for encoded dimensions, using raw actions")
     else:
-        print("Warning: No normalization stats found, using raw actions")
+        print("Warning: No normalization stats found in dataset, using raw actions")
print(f"Encoded chunks shape: {encoded_chunks.shape}")
@@ -394,6 +483,7 @@ def main(
'delta_dim_list': delta_dim_list,
'use_delta_transform': use_delta_transform,
'state_key': state_key,
'normalization_mode': norm_mode.value,
'action_horizon': action_horizon,
'num_training_chunks': len(encoded_chunks),
'compression_stats': compression_stats,
@@ -402,8 +492,41 @@ def main(
with open(output_path / "metadata.json", 'w') as f:
json.dump(metadata, f, indent=2)
print(f"\nSaved FAST tokenizer to {output_path}")
print(f"\nSaved FAST tokenizer to {output_path}")
print(f"Metadata: {json.dumps(metadata, indent=2)}")
# Push to Hugging Face Hub if requested
if push_to_hub:
# Determine the hub repository ID
if hub_repo_id is None:
hub_repo_id = output_path.name
print(f"\nNo hub_repo_id provided, using: {hub_repo_id}")
print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}")
print(f" Private: {hub_private}")
try:
# Use the tokenizer's push_to_hub method
tokenizer.push_to_hub(
repo_id=hub_repo_id,
private=hub_private,
commit_message=f"Upload FAST tokenizer trained on {repo_id}"
)
# Also upload the metadata.json file separately
api = HfApi()
api.upload_file(
path_or_fileobj=str(output_path / "metadata.json"),
path_in_repo="metadata.json",
repo_id=hub_repo_id,
repo_type="model",
commit_message="Upload tokenizer metadata"
)
print(f"Successfully pushed tokenizer to: https://huggingface.co/{hub_repo_id}")
except Exception as e:
print(f"Error pushing to hub: {e}")
print(" Make sure you're logged in with `huggingface-cli login`")
if __name__ == "__main__":
+28
View File
@@ -0,0 +1,28 @@
#!/bin/bash
# FSDP training script for PI05 with aggressive memory optimization
# Use this for large models that OOM with standard DDP
accelerate launch --config_file /admin/home/jade_choghari/lerobot/fsdp_config.yaml \
$(which lerobot-train) \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fsdp \
--job_name=libero_training_fsdp \
--policy.repo_id=jade_choghari/pi05-fast-libero-fsdp \
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
--policy.dtype=bfloat16 \
--steps=100000 \
--save_freq=10 \
--batch_size=8 \
--policy.device=cuda \
--policy.fast_only=true \
--policy.scheduler_warmup_steps=2000 \
--policy.scheduler_decay_steps=60000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=false \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=pi05-libero-training-fsdp
+24
View File
@@ -0,0 +1,24 @@
export CUDA_LAUNCH_BLOCKING=1
lerobot-train \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \
--job_name=libero_training_fast \
--policy.repo_id=jade_choghari/pi05-fast-libero \
--policy.path=/fsx/jade_choghari/models/pi05-base \
--policy.dtype=bfloat16 \
--steps=100000 \
--save_freq=20000 \
--batch_size=4 \
--policy.device=cuda \
--policy.fast_only=true \
--policy.scheduler_warmup_steps=1000 \
--policy.scheduler_decay_steps=30000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=true \
--rename_map='{
"observation.images.image1": "observation.images.base_0_rgb",
"observation.images.image2": "observation.images.left_wrist_0_rgb",
}' \
--policy.empty_cameras=1 \
# /fsx/jade_choghari/.cache/huggingface/lerobot/jadechoghari/collect-data
+15
View File
@@ -0,0 +1,15 @@
#!/bin/bash
#SBATCH --job-name=pi05-train
#SBATCH --time=24:00:00
#SBATCH --qos=high
#SBATCH --gres=gpu:8
#SBATCH --mem=256G
#SBATCH --partition=hopper-prod
#SBATCH --output=/fsx/jade_choghari/logs/%x-%j.out
#SBATCH --error=/fsx/jade_choghari/logs/%x-%j.err
srun \
--container-image=/fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh \
--container-mounts=/fsx/jade_choghari \
--container-workdir=$HOME/lerobot \
bash /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/train_multi.sh
+32 -19
View File
@@ -1,23 +1,36 @@
-rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
-accelerate launch --multi_gpu --num_processes=2 \
+#!/bin/bash
+set -euxo pipefail
+
+# Source YOUR Miniforge conda (mounted from FSX)
+source /fsx/jade_choghari/miniforge3/etc/profile.d/conda.sh
+conda activate lerobot
+
+accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
 $(which lerobot-train) \
   --dataset.repo_id=local \
-  --dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
-  --output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
-  --job_name=pi0_multi_training \
-  --policy.repo_id=jadechoghari/pi0-base1 \
-  --policy.path=lerobot/pi05_base \
+  --dataset.root=/fsx/jade_choghari/data/libero \
+  --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_mean_1 \
+  --job_name=libero_training_fast \
+  --policy.repo_id=jade_choghari/pi05-fast-libero \
+  --policy.path=/fsx/jade_choghari/models/pi05-base \
   --policy.dtype=bfloat16 \
-  --steps=50000 \
-  --save_freq=5000 \
-  --rename_map='{
-    "observation.images.base": "observation.images.base_0_rgb",
-    "observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
-    "observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
-  }' \
-  --batch_size=1 \
-  --policy.device=cpu
-# --wandb.enable=true \
-# --wandb.disable_artifact=true \
-# --wandb.project=pi05hi-training \
+  --steps=100000 \
+  --save_freq=20000 \
+  --batch_size=4 \
+  --policy.device=cuda \
+  --policy.fast_only=true \
+  --policy.scheduler_warmup_steps=4000 \
+  --policy.scheduler_decay_steps=100000 \
+  --policy.scheduler_decay_lr=1e-5 \
+  --policy.gradient_checkpointing=true \
+  --policy.chunk_size=10 \
+  --policy.n_action_steps=10 \
+  --policy.max_action_tokens=256 \
+  --rename_map='{
+    "observation.images.image1": "observation.images.base_0_rgb",
+    "observation.images.image2": "observation.images.left_wrist_0_rgb",
+  }' \
+  --policy.empty_cameras=1 \
+  --wandb.enable=true \
+  --wandb.disable_artifact=true \
+  --wandb.project=pi05-libero-training \
+533 -5
View File
@@ -29,7 +29,9 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from scipy.fft import idct
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import (
@@ -223,7 +225,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
task = self.get_task(self.transition)
if task is None:
raise ValueError("Task cannot be None")
# Tokenize the task (this will create CPU tensors)
tokenized_prompt = self._tokenize_text(task)
@@ -352,7 +353,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
else:
# If at max length, replace the last token with EOS
input_ids[i, last_token_pos] = eos_token_id
return {"input_ids": input_ids, "attention_mask": attention_mask}
def get_config(self) -> dict[str, Any]:
@@ -453,10 +453,11 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
tokenizer_name: str | None = None
tokenizer: Any | None = None
trust_remote_code: bool = True
max_action_tokens: int = 32
max_action_tokens: int = 256
# Internal tokenizer instance (not part of the config)
action_tokenizer: Any = field(default=None, init=False, repr=False)
_paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
_fast_skip_tokens: int = field(default=128, init=False, repr=False)
def __post_init__(self):
"""
Initializes the action tokenizer after the dataclass is created.
@@ -488,6 +489,9 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True, add_eos_token=True, add_bos_token=False)
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
@@ -520,12 +524,17 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Converts action tokens to PaliGemma tokens.
"""
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Tokenizes the action tensor and creates a mask.
Args:
-        action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
+        action: The input action tensor to tokenize. Shape: (B, H, action_dim) or (H, action_dim)
Returns:
A tuple of (tokens, mask) where:
@@ -568,8 +577,22 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
if tokens.dim() > 1:
tokens = tokens.flatten()
bos_id = self._paligemma_tokenizer.bos_token_id
# add bos
tokens = torch.cat([
torch.tensor([bos_id], device=action.device),
torch.tensor(self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), device=action.device),
self._act_tokens_to_paligemma_tokens(tokens),
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
])
# Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens:
import logging
logging.warning(
f"Token length ({len(tokens)}) exceeds max length ({self.max_action_tokens}), truncating. "
"Consider increasing the `max_token_len` in your model config if this happens frequently."
)
tokens = tokens[:self.max_action_tokens]
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
else:
@@ -659,3 +682,508 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
}
return features
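# Sketch of the sequence layout built by _tokenize_action above:
# [BOS] ["Action", ":", ...] [remapped FAST ids ...] ["|"] [padding ...]
# with a boolean mask over real tokens; overlong sequences are truncated.
# Ids below are illustrative only, not real tokenizer output.
import torch

max_action_tokens = 8
tokens = torch.tensor([2, 9001, 9002, 9003, 7])
if len(tokens) > max_action_tokens:
    tokens = tokens[:max_action_tokens]
    mask = torch.ones(max_action_tokens, dtype=torch.bool)
else:
    pad = max_action_tokens - len(tokens)
    mask = torch.cat([torch.ones(len(tokens), dtype=torch.bool),
                      torch.zeros(pad, dtype=torch.bool)])
    tokens = torch.cat([tokens, torch.zeros(pad, dtype=torch.long)])
print(tokens, mask)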
@dataclass
@ProcessorStepRegistry.register(name="action_detokenizer_processor_1")
class ActionDetokenizerProcessorStep1(ActionProcessorStep):
"""
Processor step to detokenize action tokens back to continuous actions.
This step takes tokenized actions (e.g., from model predictions), decodes them using
a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
and returns the continuous action tensor.
This is the inverse operation of ActionTokenizerProcessorStep and is typically used
during inference to convert predicted tokens back to executable actions.
Requires the `transformers` library to be installed.
Attributes:
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
action_horizon: The number of timesteps for actions.
action_dim: The dimensionality of each action.
relaxed_decoding: Whether to use relaxed decoding for actions (allows graceful handling of partial sequences).
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None
trust_remote_code: bool = True
action_horizon: int = 1
action_dim: int = 7
relaxed_decoding: bool = False
# Internal tokenizer instance (not part of the config)
action_tokenizer: Any = field(default=None, init=False, repr=False)
_paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
_fast_skip_tokens: int = field(default=128, init=False, repr=False)
def __post_init__(self):
"""
Initializes the action detokenizer after the dataclass is created.
It checks for the availability of the `transformers` library and loads the tokenizer
either from a provided object or by name from the Hugging Face Hub.
Raises:
ImportError: If the `transformers` library is not installed.
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
"""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self.action_tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if AutoProcessor is None:
raise ImportError("AutoProcessor is not available")
self.action_tokenizer = AutoProcessor.from_pretrained(
self.tokenizer_name, trust_remote_code=self.trust_remote_code
)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224",
trust_remote_code=True,
add_eos_token=True,
add_bos_token=False
)
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
"""
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
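# Worked example (a sketch; assumes the PaliGemma tokenizer reports a
# 257152-token vocabulary, which this diff does not state explicitly):
# action token 0 maps to PaliGemma id 257152 - 1 - 128 - 0 = 257023, and
# feeding 257023 back through the same formula returns 0, so the mapping is
# its own inverse and encode/decode can share it.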
def decode_actions_with_fast(
self,
token_ids: list[int],
time_horizon: int,
action_dim: int,
relaxed_decoding: bool = False
) -> list:
"""
Decodes action token IDs back to continuous action values using the FAST tokenizer.
Args:
token_ids: List of token IDs to decode.
time_horizon: The number of timesteps for actions.
action_dim: The dimensionality of each action.
relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
Returns:
A list representing the decoded actions.
"""
# Use the action tokenizer's decode method
# The FAST tokenizer should have a decode method that converts tokens back to actions
try:
decoded_actions = self.action_tokenizer.decode(
token_ids,
time_horizon=time_horizon,
action_dim=action_dim
)
return decoded_actions
except Exception as e:
if relaxed_decoding:
# With relaxed decoding, degrade gracefully: retrying the same token
# sequence cannot succeed, so fall back to zero actions immediately.
import logging
logging.warning(f"Relaxed decoding failed: {e}. Returning zero actions.")
return [[0.0] * action_dim for _ in range(time_horizon)]
else:
raise
def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Extracts actions from predicted output tokens using the FAST model.
Args:
tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)
Returns:
The extracted actions as a tensor of shape (B, action_horizon, action_dim) or (action_horizon, action_dim).
"""
# Handle single sample (add batch dimension)
single_sample = tokens.dim() == 1
if single_sample:
tokens = tokens.unsqueeze(0)
# Decode predicted output tokens
decoded_tokens = self._paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
# Clean the decoded tokens by removing "Action:" prefix and extracting the relevant part
cleaned_tokens = [
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
for tokens_sequence in decoded_tokens
]
# Re-encode the cleaned text to get raw action tokens
raw_action_tokens = [
self._paligemma_tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
for sample_tokens in cleaned_tokens
]
# Convert PaliGemma tokens back to action tokens
action_tokens = [
self._paligemma_tokens_to_act_tokens(raw_action_token)
for raw_action_token in raw_action_tokens
]
# Decode each sample's tokens to continuous actions. `tokens` stays bound to
# the input tensor so the `tokens.device` lookup below remains valid.
decoded_actions = [
torch.tensor(
self.decode_actions_with_fast(
tok.flatten().tolist(),
time_horizon=self.action_horizon,
action_dim=self.action_dim,
relaxed_decoding=self.relaxed_decoding,
),
device=tokens.device,
).squeeze(0)
for tok in action_tokens
]
# Stack into a batch
result = torch.stack(decoded_actions, dim=0)
# Remove batch dimension if input was single sample
if single_sample:
result = result.squeeze(0)
return result
def action(self, action: torch.Tensor) -> torch.Tensor:
"""
Detokenizes action tokens back to continuous actions.
Args:
action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)
Returns:
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
"""
return self.extract_actions(action)
def get_config(self) -> dict[str, Any]:
"""
Returns the serializable configuration of the processor.
Note: The tokenizer object itself is not serialized. If the processor was initialized
with a tokenizer name, that name will be included in the config.
Returns:
A dictionary with the processor's configuration parameters.
"""
config = {
"trust_remote_code": self.trust_remote_code,
"action_horizon": self.action_horizon,
"action_dim": self.action_dim,
"relaxed_decoding": self.relaxed_decoding,
}
# Only save tokenizer_name if it was used to create the tokenizer
if self.tokenizer_name is not None and self.tokenizer is None:
config["tokenizer_name"] = self.tokenizer_name
return config
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Updates feature definitions to reflect detokenized actions.
This updates the policy features dictionary to indicate that the action
has been detokenized from token IDs back to continuous values.
Args:
features: The dictionary of existing policy features.
Returns:
The updated dictionary of policy features.
"""
# Update the action feature to reflect the continuous action shape
if PipelineFeatureType.ACTION in features:
# Replace the action feature with the detokenized version
features[PipelineFeatureType.ACTION] = {
"action": PolicyFeature(
type=FeatureType.ACTION,  # continuous action values, not token IDs
shape=(self.action_horizon, self.action_dim)
)
}
return features
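# NOTE: the step below ("action_detokenizer_processor", without the "_1" suffix)
# exposes the same interface but decodes differently: instead of delegating to the
# FAST processor's own `decode`, it truncates the PaliGemma token sequence at "|",
# maps the ids back to action tokens, and reconstructs actions locally via BPE
# decode plus an inverse DCT.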
@dataclass
@ProcessorStepRegistry.register(name="action_detokenizer_processor")
class ActionDetokenizerProcessorStep(ActionProcessorStep):
"""
Processor step to detokenize action tokens back to continuous actions.
This step takes tokenized actions (e.g., from model predictions), decodes them using
a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
and returns the continuous action tensor.
This is the inverse operation of ActionTokenizerProcessorStep and is typically used
during inference to convert predicted tokens back to executable actions.
Requires the `transformers` library to be installed.
Attributes:
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
action_horizon: The number of timesteps for actions.
action_dim: The dimensionality of each action.
relaxed_decoding: Whether to use relaxed decoding for actions (allows graceful handling of partial sequences).
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None
trust_remote_code: bool = True
action_horizon: int = 1
action_dim: int = 7
relaxed_decoding: bool = False
# Internal tokenizer instance (not part of the config)
action_tokenizer: Any = field(default=None, init=False, repr=False)
_paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
_fast_skip_tokens: int = field(default=128, init=False, repr=False)
def __post_init__(self):
"""
Initializes the action detokenizer after the dataclass is created.
It checks for the availability of the `transformers` library and loads the tokenizer
either from a provided object or by name from the Hugging Face Hub.
Raises:
ImportError: If the `transformers` library is not installed.
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
"""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self.action_tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if AutoProcessor is None:
raise ImportError("AutoProcessor is not available")
self.action_tokenizer = AutoProcessor.from_pretrained(
self.tokenizer_name, trust_remote_code=self.trust_remote_code
)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224",
trust_remote_code=True,
add_eos_token=True,
add_bos_token=False
)
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
"""
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
def decode_actions_with_fast(
self,
token_ids: list,
time_horizon: int,
action_dim: int,
relaxed_decoding: bool = True
) -> np.ndarray:
"""
Decodes action token IDs back to continuous action values using the FAST tokenizer.
Args:
token_ids: List of per-sample token-ID sequences (one sequence per batch element).
time_horizon: The number of timesteps for actions.
action_dim: The dimensionality of each action.
relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
Returns:
A numpy array of shape (batch, time_horizon, action_dim) with the decoded actions.
"""
decoded_actions = []
for token in token_ids:
try:
decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token
if relaxed_decoding:
# expected sequence length
expected_seq_len = time_horizon * action_dim
diff = expected_seq_len - decoded_dct_coeff.shape[0]
# apply truncation if too long
if diff < 0:
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]  # truncate on the right
# apply padding if too short
elif diff > 0:
decoded_dct_coeff = np.pad(
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
)
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
assert decoded_dct_coeff.shape == (
time_horizon,
action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
)
except Exception as e:
import logging
logging.warning(f"Error decoding tokens: {e}. Tokens: {token}. Falling back to zero actions.")
decoded_dct_coeff = np.zeros((time_horizon, action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
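# Standalone sketch of the DCT half of the decode above (toy values; `dct` is
# named here only for the illustration and is not used by this file):
#
#     import numpy as np
#     from scipy.fftpack import dct, idct
#     scale, H, D = 10.0, 10, 7
#     actions = np.random.randn(H, D)                               # continuous chunk
#     coeff = np.round(dct(actions, axis=0, norm="ortho") * scale)  # quantize DCT coefficients
#     recon = idct(coeff / scale, axis=0, norm="ortho")             # what this method computes
#     # recon matches actions up to quantization error on the order of 0.5 / scale;
#     # the real pipeline additionally round-trips `coeff` through chars and BPE.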
def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Extracts actions from predicted output tokens using the FAST model.
Args:
tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)
Returns:
The extracted actions as a tensor of shape (B, action_horizon, action_dim) or (action_horizon, action_dim).
"""
# Handle single sample (add batch dimension)
single_sample = tokens.dim() == 1
if single_sample:
tokens = tokens.unsqueeze(0)
decoded_tokens = [
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
for seq in tokens
]
cleaned_tokens = []
for token_seq in decoded_tokens:
if "|" in token_seq:
token_seq = token_seq[:token_seq.index("|")]
cleaned_tokens.append(token_seq)
raw_action_tokens = [
torch.tensor(
self._paligemma_tokenizer.convert_tokens_to_ids(token_seq),
dtype=torch.long,
device=tokens.device,
)
for token_seq in cleaned_tokens
]
action_tokens = [
self._paligemma_tokens_to_act_tokens(raw_action_token)
for raw_action_token in raw_action_tokens
]
actions = self.decode_actions_with_fast(
action_tokens,
time_horizon=self.action_horizon,
action_dim=self.action_dim,
relaxed_decoding=self.relaxed_decoding,
)
result = torch.tensor(actions, device=tokens.device, dtype=torch.float32)
# Remove the batch dimension if the input was a single sample
if single_sample:
result = result.squeeze(0)
return result
def action(self, action: torch.Tensor) -> torch.Tensor:
"""
Detokenizes action tokens back to continuous actions.
Args:
action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)
Returns:
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
"""
return self.extract_actions(action)
def get_config(self) -> dict[str, Any]:
"""
Returns the serializable configuration of the processor.
Note: The tokenizer object itself is not serialized. If the processor was initialized
with a tokenizer name, that name will be included in the config.
Returns:
A dictionary with the processor's configuration parameters.
"""
config = {
"trust_remote_code": self.trust_remote_code,
"action_horizon": self.action_horizon,
"action_dim": self.action_dim,
"relaxed_decoding": self.relaxed_decoding,
}
# Only save tokenizer_name if it was used to create the tokenizer
if self.tokenizer_name is not None and self.tokenizer is None:
config["tokenizer_name"] = self.tokenizer_name
return config
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Updates feature definitions to reflect detokenized actions.
This updates the policy features dictionary to indicate that the action
has been detokenized from token IDs back to continuous values.
Args:
features: The dictionary of existing policy features.
Returns:
The updated dictionary of policy features.
"""
# Update the action feature to reflect the continuous action shape
if PipelineFeatureType.ACTION in features:
# Replace the action feature with the detokenized version
features[PipelineFeatureType.ACTION] = {
"action": PolicyFeature(
type=FeatureType.ACTION,  # continuous action values, not token IDs
shape=(self.action_horizon, self.action_dim)
)
}
return features
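# Minimal usage sketch (hypothetical values; only the "physical-intelligence/fast"
# tokenizer name appears elsewhere in this file, and the horizon/dim below are examples):
#
#     detok = ActionDetokenizerProcessorStep(
#         tokenizer_name="physical-intelligence/fast",
#         action_horizon=10,
#         action_dim=7,
#         relaxed_decoding=True,
#     )
#     token_ids = torch.randint(0, 1000, (2, 32))  # toy (B, max_action_tokens) input
#     actions = detok.action(token_ids)            # -> (2, 10, 7) continuous actions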
+1
View File
@@ -173,6 +173,7 @@ def rollout(
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
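# Presumably the postprocessor chain now includes a detokenizer step, so `action`
# leaves the rollout as a continuous tensor rather than FAST token IDs (a reading
# of this diff, not something it states).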
+9 -1
View File
@@ -62,6 +62,7 @@ def update_policy(
accelerator: Accelerator,
lr_scheduler=None,
lock=None,
postprocessor=None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -90,6 +91,10 @@ def update_policy(
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
# action = policy.predict_action_chunk(batch)
# if postprocessor is not None:
# action = postprocessor(action)
# breakpoint()
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
@@ -151,7 +156,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
accelerator = Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=4, kwargs_handlers=[ddp_kwargs])
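# NOTE: gradient_accumulation_steps=4 only takes effect if the training step runs
# inside `with accelerator.accumulate(policy):`; otherwise every batch still triggers
# an optimizer step. When active, the effective batch size is
# batch_size * 4 * num_processes.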
init_logging(accelerator=accelerator)
@@ -206,6 +211,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()
@@ -244,6 +250,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
**postprocessor_kwargs,
)
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -343,6 +350,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
postprocessor=postprocessor,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we