Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-12 15:19:43 +00:00)
Compare commits
11 Commits
995a46b302, 23d4846423, 7d897daeb2, 7556c7fd70, 4434c863b4, 4b40153c32, f0923e5c86, 8edd544bbe, e682ef05f9, 9b5ac4387c, 5781754c30
New binary image file added (51 KiB, not shown).
@@ -0,0 +1,18 @@
import torch
from huggingface_hub import HfApi

import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

dataset = LeRobotDataset(repo_id="lerobot/libero")

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
    batch_size=4,
    shuffle=True,
)
batch = next(iter(dataloader))
print(batch.keys())

breakpoint()
@@ -2,25 +2,26 @@
  "repo_id": "local",
  "vocab_size": 1024,
  "scale": 10.0,
  "encoded_dims": "0:15",
  "encoded_dims": "0:7",
  "encoded_dim_ranges": [
    [
      0,
      15
      7
    ]
  ],
  "total_encoded_dims": 15,
  "total_encoded_dims": 7,
  "delta_dims": null,
  "delta_dim_list": null,
  "use_delta_transform": false,
  "state_key": "observation.state",
  "action_horizon": 50,
  "num_training_chunks": 4900,
  "normalization_mode": "QUANTILES",
  "action_horizon": 10,
  "num_training_chunks": 25065,
  "compression_stats": {
    "compression_ratio": 15.85791309863622,
    "mean_token_length": 47.295,
    "p99_token_length": 90.0,
    "min_token_length": 9.0,
    "max_token_length": 109.0
    "compression_ratio": 3.464660463274599,
    "mean_token_length": 20.204,
    "p99_token_length": 36.00999999999999,
    "min_token_length": 5.0,
    "max_token_length": 38.0
  }
}
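For reference, the compression figures above are consistent with reading compression_ratio as the number of raw action values per chunk divided by the mean token length (a plausible interpretation, not stated explicitly in the file): the new tokenizer encodes action_horizon × total_encoded_dims = 10 × 7 = 70 values into about 20.2 tokens on average.

```python
# Quick check, assuming compression_ratio = (action_horizon * total_encoded_dims) / mean_token_length
old_ratio = (50 * 15) / 47.295  # ≈ 15.858, matches the removed value
new_ratio = (10 * 7) / 20.204   # ≈ 3.465, matches the added value
print(old_ratio, new_ratio)
```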
@@ -1,11 +1,11 @@
{
  "action_dim": 15,
  "action_dim": 7,
  "auto_map": {
    "AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
  },
  "min_token": -71,
  "min_token": -32,
  "processor_class": "UniversalActionProcessor",
  "scale": 10.0,
  "time_horizon": 50,
  "time_horizon": 10,
  "vocab_size": 1024
}
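The auto_map / processor_class entries are what let transformers resolve the custom FAST processor directly from a Hub repo. A minimal loading sketch, assuming the tokenizer was pushed to the jadechoghari/fast-libero-tokenizer-mean-std repo referenced later in this diff:

```python
from transformers import AutoProcessor

# trust_remote_code is required because UniversalActionProcessor is defined
# inside the repo (processing_action_tokenizer.py), not in transformers itself.
action_tokenizer = AutoProcessor.from_pretrained(
    "jadechoghari/fast-libero-tokenizer-mean-std",
    trust_remote_code=True,
)
# attributes used elsewhere in this diff
print(action_tokenizer.scale, action_tokenizer.min_token)
```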
File diff suppressed because it is too large (+4403, -4023 lines).
@@ -162,7 +162,7 @@ class LeRobotDatasetMetadata:
        self.info = load_info(self.root)
        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
        self.tasks = load_tasks(self.root)
        self.tasks_high_level = load_tasks_high_level(self.root)
        # self.tasks_high_level = load_tasks_high_level(self.root)
        self.episodes = load_episodes(self.root)
        self.stats = load_stats(self.root)
@@ -40,6 +40,8 @@ class PI05Config(PreTrainedConfig):
    max_action_tokens: int = 32
    fast_vocab_size: int = 2048

    # FAST-only mode: train with only discrete action token prediction (no flow matching, no subtask)
    fast_only: bool = False

    # Flow matching parameters: see openpi `PI0Pytorch`
    num_inference_steps: int = 10
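For illustration only (the module path and other defaults are assumed here, not confirmed by this diff), the new flag would be enabled like this in code; in practice the training scripts below pass --policy.fast_only=true on the CLI:

```python
from lerobot.policies.pi05.configuration_pi05 import PI05Config  # import path assumed

config = PI05Config(
    fast_only=True,        # discrete FAST action-token prediction only
    max_action_tokens=32,
    fast_vocab_size=2048,
)
```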
@@ -21,8 +21,10 @@ from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict

import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
from scipy.fftpack import idct
from torch import Tensor, nn
from typing_extensions import Unpack
@@ -536,18 +538,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`

        self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
        self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)

        # # FAST action token embedding and prediction head
        # self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
        # self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)

        # FAST action token embedding and prediction head
        self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
        self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)

        # Apply dtype conversion to FAST layers to match model precision
        if config.dtype == "bfloat16":
            self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
            self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
        elif config.dtype == "float32":
            self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
            self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
        from transformers import AutoTokenizer
        self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
            "google/paligemma-3b-pt-224",
            trust_remote_code=True,
        )
        # # Apply dtype conversion to FAST layers to match model precision
        # if config.dtype == "bfloat16":
        #     self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
        #     self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
        # elif config.dtype == "float32":
        #     self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
        #     self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)

        # Initialize gradient checkpointing flag
        self.gradient_checkpointing_enabled = False
@@ -1213,6 +1220,514 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"fast_loss": fast_loss,
|
||||
"loss": fm_loss.mean() + 0.1 * subtask_loss + 0.05 * fast_loss, # ref: b1k winner
|
||||
}
|
||||
|
||||
def embed_prefix_fast(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
fast_action_tokens=None,
|
||||
fast_action_masks=None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
|
||||
"""Embed images, language tokens, and FAST action tokens for FAST-only mode.
|
||||
|
||||
This is a simplified version of embed_prefix without subtask tokens.
|
||||
Attention pattern:
|
||||
- Images + Language: bidirectional among themselves
|
||||
- FAST: attend to images + language, causal among themselves
|
||||
|
||||
Args:
|
||||
images: List of image tensors
|
||||
img_masks: List of image masks
|
||||
tokens: Language instruction tokens
|
||||
masks: Attention masks for tokens
|
||||
fast_action_tokens: FAST action tokens (discrete token IDs)
|
||||
fast_action_masks: Padding masks for FAST action tokens
|
||||
|
||||
Returns:
|
||||
embs: Concatenated embeddings [images, tokens, fast_action_tokens]
|
||||
pad_masks: Padding masks
|
||||
att_masks: 2D attention mask
|
||||
total_T_images: Total number of image tokens
|
||||
num_fast_embs: Number of FAST action token embeddings
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_mask_segments = []
|
||||
total_T_images = 0
|
||||
num_fast_embs = 0
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
def image_embed_func(img):
|
||||
return self.paligemma_with_expert.embed_image(img)
|
||||
|
||||
img_emb = self._apply_checkpoint(image_embed_func, img)
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
att_mask_segments.append(('image', num_img_embs))
|
||||
total_T_images += num_img_embs
|
||||
|
||||
# Process language instruction tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(masks)
|
||||
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_mask_segments.append(('language', num_lang_embs))
|
||||
|
||||
# Process FAST action tokens (discrete token IDs)
|
||||
if fast_action_tokens is not None:
|
||||
def fast_action_embed_func(fast_action_tokens):
|
||||
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||
fast_emb_dim = fast_emb.shape[-1]
|
||||
return fast_emb * math.sqrt(fast_emb_dim)
|
||||
|
||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||
embs.append(fast_action_emb)
|
||||
|
||||
if fast_action_masks is not None:
|
||||
fast_pad_mask = fast_action_masks
|
||||
else:
|
||||
bsize = fast_action_tokens.shape[0]
|
||||
num_fast_embs = fast_action_tokens.shape[1]
|
||||
fast_pad_mask = torch.ones(bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device)
|
||||
|
||||
num_fast_embs = fast_action_tokens.shape[1]
|
||||
pad_masks.append(fast_pad_mask)
|
||||
att_mask_segments.append(('fast', num_fast_embs))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
|
||||
# Create custom 2D attention mask for FAST-only mode:
|
||||
# - Images + Language: bidirectional among themselves
|
||||
# - FAST: attend to images + language, causal among themselves
|
||||
att_masks = self._create_custom_attention_mask_fast(att_mask_segments, pad_masks, bsize)
|
||||
|
||||
return embs, pad_masks, att_masks, total_T_images, num_fast_embs
|
||||
|
||||
def _create_custom_attention_mask_fast(self, att_mask_segments, pad_masks, bsize):
|
||||
"""Create custom 2D attention mask for FAST-only mode.
|
||||
|
||||
Attention rules:
|
||||
- Images + Language: bidirectional among themselves
|
||||
- FAST: attend to images + language, causal among themselves
|
||||
"""
|
||||
total_len = sum(length for _, length in att_mask_segments)
|
||||
device = pad_masks.device
|
||||
|
||||
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
||||
|
||||
positions = []
|
||||
current_pos = 0
|
||||
for seg_type, seg_len in att_mask_segments:
|
||||
positions.append((seg_type, current_pos, current_pos + seg_len))
|
||||
current_pos += seg_len
|
||||
|
||||
for i, (query_type, query_start, query_end) in enumerate(positions):
|
||||
for j, (key_type, key_start, key_end) in enumerate(positions):
|
||||
# Images and Language can attend to each other bidirectionally
|
||||
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# FAST tokens attend to images + language
|
||||
elif query_type == 'fast' and key_type in ['image', 'language']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# FAST tokens attend causally to themselves
|
||||
elif query_type == 'fast' and key_type == 'fast':
|
||||
fast_len = query_end - query_start
|
||||
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||
|
||||
# Apply padding masks
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
|
||||
return att_2d_masks
|
||||
|
||||
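The attention rules described in the two docstrings above (prefix tokens bidirectional among themselves, FAST tokens reading the whole prefix but causal among themselves) can be reproduced with a small standalone sketch; this is illustrative only and not the repository's implementation:

```python
import torch

def toy_fast_attention_mask(n_prefix: int, n_fast: int) -> torch.Tensor:
    """Boolean (T, T) mask where True means the query position may attend to the key position."""
    total = n_prefix + n_fast
    mask = torch.zeros(total, total, dtype=torch.bool)
    mask[:n_prefix, :n_prefix] = True                      # images + language: bidirectional
    mask[n_prefix:, :n_prefix] = True                      # FAST tokens read the full prefix
    mask[n_prefix:, n_prefix:] = torch.tril(               # FAST tokens: causal among themselves
        torch.ones(n_fast, n_fast, dtype=torch.bool)
    )
    return mask

print(toy_fast_attention_mask(3, 4).int())
```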
def forward_fast_only(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
fast_action_tokens,
|
||||
fast_action_masks,
|
||||
) -> dict:
|
||||
"""Forward pass for FAST-only mode (no flow matching, no subtask).
|
||||
|
||||
This implements the Pi0FAST training objective: predict next action token
|
||||
using cross-entropy loss.
|
||||
|
||||
Args:
|
||||
images: List of image tensors
|
||||
img_masks: List of image masks
|
||||
tokens: Language instruction tokens
|
||||
masks: Attention masks for tokens
|
||||
fast_action_tokens: Discrete action token IDs [B, max_action_tokens]
|
||||
fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens]
|
||||
|
||||
Returns:
|
||||
Dictionary with 'fast_loss' and 'loss' keys
|
||||
"""
|
||||
if fast_action_tokens is None or fast_action_masks is None:
|
||||
raise ValueError("fast_action_tokens and fast_action_masks are required for FAST-only mode")
|
||||
|
||||
# Embed prefix with FAST tokens
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, num_fast_embs = self.embed_prefix_fast(
|
||||
images, img_masks, tokens, masks,
|
||||
fast_action_tokens=fast_action_tokens,
|
||||
fast_action_masks=fast_action_masks
|
||||
)
|
||||
|
||||
# Convert embeddings to bfloat16 if needed
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
# For next-token prediction, we input tokens [0:T-1] to predict tokens [1:T]
|
||||
# So we remove the last token from input
|
||||
# input_embs = prefix_embs[:, :-1]
|
||||
# input_pad_masks = prefix_pad_masks[:, :-1]
|
||||
# input_att_masks = prefix_att_masks[:, :-1, :-1]
|
||||
|
||||
input_embs = prefix_embs
|
||||
input_pad_masks = prefix_pad_masks
|
||||
input_att_masks = prefix_att_masks
|
||||
|
||||
position_ids = torch.cumsum(input_pad_masks, dim=1) - 1
|
||||
att_2d_4d = self._prepare_attention_masks_4d(input_att_masks, dtype=input_embs.dtype)
|
||||
|
||||
# Forward pass through paligemma (language model only, no action expert)
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[input_embs, None], # No suffix/action expert
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# Get logits for FAST action tokens using the FAST LM head
|
||||
# We only compute logits for the positions that predict FAST tokens
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
# The FAST tokens start at position (total_T_images + num_lang_tokens)
|
||||
# For next-token prediction:
|
||||
# - Position (fast_start - 1) in input predicts fast_action_tokens[0]
|
||||
# - Position (fast_start) in input predicts fast_action_tokens[1], etc.
|
||||
|
||||
# Targets are the FAST action tokens
|
||||
fast_targets = fast_action_tokens # (B, num_fast_embs)
|
||||
T_lang = masks.shape[1]
|
||||
fast_start = total_T_images + T_lang
|
||||
|
||||
|
||||
# Extract logits for FAST token prediction
|
||||
# Input positions [fast_start-1 : fast_start-1+num_fast_embs] predict FAST tokens
|
||||
# fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, hidden_dim)
|
||||
fast_hidden = prefix_out[:, -fast_targets.shape[1]:, :]
|
||||
fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size)
|
||||
|
||||
# Shift left for next-step prediction and shift target
|
||||
# logits[:, i] predicts targets[:, i+1]
|
||||
fast_logits_for_pred = fast_logits_for_pred[:, :-1, :] # Shift logits left
|
||||
fast_targets = fast_targets[:, 1:] # Shift targets right
|
||||
fast_action_masks = fast_action_masks[:, 1:] # Shift masks to match targets
|
||||
|
||||
# from transformers import AutoTokenizer
|
||||
# self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
||||
# "google/paligemma-3b-pt-224",
|
||||
# trust_remote_code=True,
|
||||
# add_eos_token=True,
|
||||
# add_bos_token=False
|
||||
# )
|
||||
# # remove
|
||||
# decoded_tokens = [
|
||||
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
# for seq in fast_targets
|
||||
# ]
|
||||
# corrected_tokens = [
|
||||
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
# for seq in fast_logits_for_pred.argmax(dim=-1)
|
||||
# ]
|
||||
# breakpoint()
|
||||
|
||||
# Compute cross-entropy loss
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
|
||||
fast_targets_flat = fast_targets.reshape(-1)
|
||||
|
||||
fast_loss_per_token = loss_fct(fast_logits_flat, fast_targets_flat)
|
||||
fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape)
|
||||
|
||||
# Apply mask and compute mean loss
|
||||
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
|
||||
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
|
||||
|
||||
return {
|
||||
"fast_loss": fast_loss,
|
||||
"loss": fast_loss,
|
||||
}
|
||||
|
||||
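forward_fast_only above is a plain next-token objective: logits at position i are scored against the token at position i + 1, and padded positions are masked out of the mean. The shift-and-mask pattern in isolation, with toy shapes rather than the model's real ones:

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 6, 11
logits = torch.randn(B, T, V)               # stand-in for lm_head(hidden_states)
targets = torch.randint(0, V, (B, T))       # stand-in for FAST action token ids
mask = torch.ones(B, T, dtype=torch.bool)
mask[:, -2:] = False                        # pretend the last two positions are padding

shifted_logits = logits[:, :-1, :]          # position i predicts ...
shifted_targets = targets[:, 1:]            # ... the token at position i + 1
shifted_mask = mask[:, 1:]

per_token = F.cross_entropy(
    shifted_logits.reshape(-1, V), shifted_targets.reshape(-1), reduction="none"
).reshape(shifted_targets.shape)
loss = (per_token * shifted_mask).sum() / shifted_mask.sum().clamp(min=1)
print(loss)
```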
@torch.no_grad()
|
||||
def sample_actions_fast(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
max_decoding_steps=None,
|
||||
temperature=0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Inefficient but safe autoregressive decoding for FAST tokens.
|
||||
Matches the pattern of _generate_subtask_tokens.
|
||||
"""
|
||||
if max_decoding_steps is None:
|
||||
max_decoding_steps = self.config.max_action_tokens
|
||||
|
||||
bsize = tokens.shape[0]
|
||||
device = tokens.device
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
# add bos token after tokens
|
||||
bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
|
||||
tokens = torch.cat([tokens, bos_token], dim=1)
|
||||
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
|
||||
|
||||
# 1. Initial Embedding (Matches Training Prefix)
|
||||
# prefix_embs will include [Images, Language Prompt]
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
|
||||
images, img_masks, tokens, masks,
|
||||
fast_action_tokens=None,
|
||||
fast_action_masks=None
|
||||
)
|
||||
|
||||
if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
|
||||
|
||||
# 2. Decoding Loop (Re-computes full sequence every step)
|
||||
for t in range(max_decoding_steps):
|
||||
# Always re-calculate position IDs from the current pad mask (matches training)
|
||||
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
# Full forward pass (No KV Cache)
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# Predict next token from the very last sequence position
|
||||
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
|
||||
|
||||
generated_action_tokens[:, t] = next_token.squeeze(-1)
|
||||
|
||||
# 3. Update Sequence for next iteration (unless it's the last step)
|
||||
if t < max_decoding_steps - 1:
|
||||
# Embed the newly generated token
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Append to embeddings
|
||||
prefix_embs = torch.cat([prefix_embs, next_token_emb], dim=1)
|
||||
|
||||
# Update padding mask (New token is always valid/1)
|
||||
prefix_pad_masks = torch.cat([
|
||||
prefix_pad_masks,
|
||||
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
], dim=1)
|
||||
|
||||
# Update 2D attention mask: Grow the matrix
|
||||
old_len = prefix_att_masks.shape[1]
|
||||
new_len = old_len + 1
|
||||
new_att_masks = torch.zeros((bsize, new_len, new_len), dtype=torch.bool, device=device)
|
||||
new_att_masks[:, :old_len, :old_len] = prefix_att_masks
|
||||
# New token attends to all non-padding tokens in the updated sequence
|
||||
new_att_masks[:, -1, :] = prefix_pad_masks
|
||||
prefix_att_masks = new_att_masks
|
||||
return generated_action_tokens
|
||||
|
||||
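Both decoding paths (the full-recompute loop above and the KV-cache version below) pick the next token with the same rule: greedy argmax at temperature 0, otherwise sampling from the temperature-scaled softmax. The rule in isolation:

```python
import torch

def pick_next_token(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    """logits: (B, vocab_size). Returns (B, 1) token ids."""
    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)
    return torch.argmax(logits, dim=-1, keepdim=True)

logits = torch.randn(2, 2048)
print(pick_next_token(logits).shape, pick_next_token(logits, temperature=0.7).shape)
```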
@torch.no_grad()
|
||||
def sample_actions_fast_kv_cache(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
max_decoding_steps=None,
|
||||
temperature=0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Efficient autoregressive decoding for FAST tokens using KV-caching.
|
||||
Only computes the prefix once, then incrementally generates tokens.
|
||||
"""
|
||||
if max_decoding_steps is None:
|
||||
max_decoding_steps = self.config.max_action_tokens
|
||||
|
||||
bsize = tokens.shape[0]
|
||||
device = tokens.device
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
bos_token = torch.full((bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device)
|
||||
tokens = torch.cat([tokens, bos_token], dim=1)
|
||||
masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
|
||||
|
||||
# 1. Initial Embedding (Matches Training Prefix)
|
||||
# prefix_embs will include [Images, Language Prompt]
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
|
||||
images, img_masks, tokens, masks,
|
||||
fast_action_tokens=None,
|
||||
fast_action_masks=None
|
||||
)
|
||||
|
||||
if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
|
||||
|
||||
# 2. Initial forward pass to populate KV cache
|
||||
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
# First forward pass with full prefix (caching enabled)
|
||||
(prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=True,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# # Get BOS token and add it as the first token in action sequence
|
||||
# bos_id = self._paligemma_tokenizer.bos_token_id
|
||||
# bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)
|
||||
|
||||
# # Embed BOS token
|
||||
# bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
|
||||
# bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
|
||||
# if prefix_embs.dtype == torch.bfloat16:
|
||||
# bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Track current sequence length for position IDs and maintain the padding mask
|
||||
current_seq_len = prefix_embs.shape[1]
|
||||
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
|
||||
current_pad_mask = prefix_pad_masks.clone() # (B, seq_len)
|
||||
|
||||
# # Update padding mask for BOS token: it's always valid
|
||||
# current_pad_mask = torch.cat([
|
||||
# current_pad_mask,
|
||||
# torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
# ], dim=1) # (B, seq_len+1)
|
||||
|
||||
# # Position ID for BOS token (continues from where prefix ended)
|
||||
# bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
|
||||
|
||||
# # Attention mask for BOS token: attends to all valid prefix positions
|
||||
# bos_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
|
||||
# bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)
|
||||
|
||||
# # Forward pass with BOS token (reusing cached KVs from prefix)
|
||||
# (bos_out, _), past_key_values = self.paligemma_with_expert.forward(
|
||||
# attention_mask=bos_att_4d,
|
||||
# position_ids=bos_position_id,
|
||||
# past_key_values=past_key_values,
|
||||
# inputs_embeds=[bos_token_emb, None],
|
||||
# use_cache=True,
|
||||
# adarms_cond=[None, None],
|
||||
# )
|
||||
|
||||
# # Update sequence length to account for BOS token
|
||||
# current_seq_len += 1
|
||||
|
||||
# Predict first action token from BOS token output
|
||||
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
|
||||
|
||||
generated_action_tokens[:, 0] = next_token.squeeze(-1)
|
||||
|
||||
# 3. Incremental Decoding Loop (using KV cache)
|
||||
for t in range(1, max_decoding_steps):
|
||||
# Embed the newly generated token
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Update padding mask: new generated token is always valid
|
||||
current_pad_mask = torch.cat([
|
||||
current_pad_mask,
|
||||
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
], dim=1) # (B, seq_len+1)
|
||||
|
||||
# Position ID for the new token (continues from where we left off)
|
||||
new_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
|
||||
|
||||
# For KV-cache: attention mask for the new token should only attend to valid positions
|
||||
# Shape: (B, 1, past_len+1) where the new token attends to valid prefix + all generated tokens
|
||||
new_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
|
||||
att_4d_incremental = self._prepare_attention_masks_4d(new_att_mask_2d, dtype=next_token_emb.dtype)
|
||||
|
||||
# Forward pass with only the new token embedding (reusing cached KVs)
|
||||
(new_out, _), past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_4d_incremental,
|
||||
position_ids=new_position_id,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[next_token_emb, None],
|
||||
use_cache=True,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# Predict next token
|
||||
last_logits = lm_head(new_out[:, -1:, :]) # (B, 1, vocab_size)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
|
||||
|
||||
generated_action_tokens[:, t] = next_token.squeeze(-1)
|
||||
|
||||
# Update sequence length
|
||||
current_seq_len += 1
|
||||
|
||||
return generated_action_tokens
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _generate_subtask_tokens(
|
||||
@@ -1480,6 +1995,34 @@ class PI05Policy(PreTrainedPolicy):
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not load tokenizer for subtask decoding: {e}")
|
||||
self.tokenizer = None
|
||||
|
||||
# Load FAST tokenizer for action detokenization (only if fast_only mode)
|
||||
self.action_tokenizer = None
|
||||
self._paligemma_tokenizer = None
|
||||
self._fast_skip_tokens = 128
|
||||
|
||||
if config.fast_only:
|
||||
try:
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
|
||||
# Load FAST tokenizer
|
||||
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||
"jadechoghari/fast-libero-tokenizer-mean-std",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Load PaliGemma tokenizer for token conversion
|
||||
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/paligemma-3b-pt-224",
|
||||
trust_remote_code=True,
|
||||
add_eos_token=True,
|
||||
add_bos_token=False
|
||||
)
|
||||
|
||||
logging.info("Loaded FAST tokenizer for action detokenization")
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not load FAST tokenizer for action detokenization: {e}")
|
||||
logging.warning("Action tokens will be returned without detokenization")
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -1556,7 +2099,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
remapped_state_dict = {}
|
||||
remap_count = 0
|
||||
@@ -1655,6 +2197,9 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Some checkpoints might have this, but current model expects different structure
|
||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||
|
||||
if key == "model.paligemma_with_expert.paligemma.lm_head.weight" or key == "paligemma_with_expert.paligemma.lm_head.weight":
|
||||
fixed_state_dict["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"] = value.clone()
|
||||
|
||||
fixed_state_dict[new_key] = value
|
||||
|
||||
return fixed_state_dict
|
||||
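The lm_head → embed_tokens copy in the hunk above appears to reconstruct tied input/output embeddings for checkpoints that only ship the output projection. The underlying idea on a toy module (the diff clones the tensor; classic weight tying shares it):

```python
import torch.nn as nn

vocab_size, width = 100, 16
lm_head = nn.Linear(width, vocab_size, bias=False)   # weight shape (vocab_size, width)

# Cloned variant (what the state-dict fix above does): same values, independent storage.
embed_tokens = nn.Embedding(vocab_size, width)
embed_tokens.weight.data.copy_(lm_head.weight.data)

# Tied variant: both modules share a single parameter.
tied_embed = nn.Embedding(vocab_size, width)
tied_embed.weight = lm_head.weight
assert tied_embed.weight.data_ptr() == lm_head.weight.data_ptr()
```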
@@ -1755,6 +2300,173 @@ class PI05Policy(PreTrainedPolicy):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
|
||||
|
||||
Args:
|
||||
tokens: PaliGemma token IDs
|
||||
|
||||
Returns:
|
||||
Action token IDs
|
||||
"""
|
||||
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
|
||||
|
||||
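_paligemma_tokens_to_act_tokens and its counterpart map FAST action-token ids into the tail of the PaliGemma vocabulary with the same affine formula, so applying the map twice returns the original id. A quick standalone check (the vocabulary size here is a placeholder, not necessarily the real tokenizer's):

```python
def act_to_paligemma(token: int, vocab_size: int, skip_tokens: int = 128) -> int:
    return vocab_size - 1 - skip_tokens - token

def paligemma_to_act(token: int, vocab_size: int, skip_tokens: int = 128) -> int:
    return vocab_size - 1 - skip_tokens - token   # same formula: the map is its own inverse

V = 257_152  # placeholder vocabulary size
for act_token in (0, 5, 1023):
    assert paligemma_to_act(act_to_paligemma(act_token, V), V) == act_token
```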
def decode_actions_with_fast(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
time_horizon: int,
|
||||
action_dim: int,
|
||||
relaxed_decoding: bool = True
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decodes action token IDs back to continuous action values using the FAST tokenizer.
|
||||
|
||||
Args:
|
||||
token_ids: List of token IDs to decode.
|
||||
time_horizon: The number of timesteps for actions.
|
||||
action_dim: The dimensionality of each action.
|
||||
relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
|
||||
|
||||
Returns:
|
||||
A numpy array representing the decoded actions.
|
||||
"""
|
||||
decoded_actions = []
|
||||
|
||||
for token in token_ids:
|
||||
try:
|
||||
decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token
|
||||
|
||||
if relaxed_decoding:
|
||||
# expected sequence length
|
||||
expected_seq_len = time_horizon * action_dim
|
||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||
|
||||
# apply truncation if too long
|
||||
if diff < 0:
|
||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # truncate on the right
|
||||
|
||||
# apply padding if too short
|
||||
elif diff > 0:
|
||||
decoded_dct_coeff = np.pad(
|
||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
time_horizon,
|
||||
action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Error decoding tokens: {e}")
|
||||
logging.warning(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((time_horizon, action_dim))
|
||||
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho"))
|
||||
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
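decode_actions_with_fast reverses the FAST pipeline: BPE-decoded characters become scaled DCT coefficients, and an orthonormal inverse DCT along the time axis recovers a (time_horizon, action_dim) chunk. A tiny round trip of just the DCT stage, using the same scale and norm conventions (the rounding step is a coarse stand-in for the real quantization, not the tokenizer's actual scheme):

```python
import numpy as np
from scipy.fftpack import dct, idct

scale = 10.0
chunk = np.random.uniform(-1.0, 1.0, size=(10, 7))      # normalized action chunk [H, D]

coeff = dct(chunk, axis=0, norm="ortho") * scale         # forward transform + scaling
coeff_q = np.round(coeff)                                # stand-in for discretization
recon = idct(coeff_q / scale, axis=0, norm="ortho")      # mirrors decode_actions_with_fast

print(float(np.abs(recon - chunk).max()))                # small reconstruction error
```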
def detokenize_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
|
||||
"""
|
||||
Detokenizes action tokens back to continuous actions.
|
||||
|
||||
This method converts predicted action tokens from the model back to continuous action values
|
||||
using the FAST tokenizer. It handles the conversion from PaliGemma token space to action token
|
||||
space, then decodes the action tokens to continuous values using DCT decoding.
|
||||
|
||||
Args:
|
||||
tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)
|
||||
action_horizon: The number of timesteps for actions.
|
||||
action_dim: The dimensionality of each action.
|
||||
|
||||
Returns:
|
||||
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True)
|
||||
if self.action_tokenizer is None or self._paligemma_tokenizer is None:
|
||||
raise ValueError(
|
||||
"Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully."
|
||||
)
|
||||
|
||||
# Handle single sample (add batch dimension)
|
||||
single_sample = tokens.dim() == 1
|
||||
if single_sample:
|
||||
tokens = tokens.unsqueeze(0)
|
||||
|
||||
# Convert token IDs to token strings
|
||||
decoded_tokens = [
|
||||
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
for seq in tokens
|
||||
]
|
||||
# Get the token sequence for "Action: " to remove it
|
||||
action_prefix_ids = self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False)
|
||||
action_prefix_tokens = self._paligemma_tokenizer.convert_ids_to_tokens(action_prefix_ids)
|
||||
action_prefix_len = len(action_prefix_tokens)
|
||||
|
||||
# Clean tokens by removing everything after the first "|" (end-of-action marker)
|
||||
# and removing all occurrences of "Action: " token sequence
|
||||
# assert that each decoded sequence begins with "Action: "
|
||||
for token_seq in decoded_tokens:
|
||||
assert (
|
||||
len(token_seq) >= 2
|
||||
and token_seq[0] == "Action"
|
||||
and token_seq[1] == ":"
|
||||
), f"Token sequence does not start with ['Action', ':']: {token_seq}"
|
||||
|
||||
cleaned_tokens = []
|
||||
for token_seq in decoded_tokens:
|
||||
# Remove everything after "|"
|
||||
if "|" in token_seq:
|
||||
token_seq = token_seq[:token_seq.index("|")]
|
||||
|
||||
# Remove all occurrences of "Action: " token sequence
|
||||
i = 0
|
||||
while i <= len(token_seq) - action_prefix_len:
|
||||
if token_seq[i:i+action_prefix_len] == action_prefix_tokens:
|
||||
# Found a match, remove it
|
||||
token_seq = token_seq[:i] + token_seq[i+action_prefix_len:]
|
||||
else:
|
||||
i += 1
|
||||
|
||||
cleaned_tokens.append(token_seq)
|
||||
|
||||
# Convert token strings back to IDs
|
||||
raw_action_tokens = [
|
||||
torch.tensor(
|
||||
self._paligemma_tokenizer.convert_tokens_to_ids(token_seq),
|
||||
dtype=torch.long,
|
||||
device=tokens.device,
|
||||
)
|
||||
for token_seq in cleaned_tokens
|
||||
]
|
||||
|
||||
# Convert PaliGemma tokens to action tokens
|
||||
action_tokens = [
|
||||
self._paligemma_tokens_to_act_tokens(raw_action_token)
|
||||
for raw_action_token in raw_action_tokens
|
||||
]
|
||||
|
||||
# Decode action tokens to continuous actions
|
||||
actions = self.decode_actions_with_fast(
|
||||
action_tokens,
|
||||
time_horizon=action_horizon,
|
||||
action_dim=action_dim
|
||||
)
|
||||
|
||||
# Convert to tensor and return
|
||||
actions_tensor = torch.tensor(actions, dtype=torch.float32, device=tokens.device)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
actions_tensor = actions_tensor.squeeze(0)
|
||||
|
||||
return actions_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -1780,6 +2492,35 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
|
||||
# FAST-only mode: use autoregressive decoding
|
||||
if self.config.fast_only:
|
||||
tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Get optional parameters
|
||||
temperature = kwargs.get("temperature", 0.0)
|
||||
max_decoding_steps = 256
|
||||
|
||||
# Sample action tokens autoregressively
|
||||
action_tokens = self.model.sample_actions_fast(
|
||||
images, img_masks, tokens, masks,
|
||||
max_decoding_steps=max_decoding_steps,
|
||||
temperature=temperature,
|
||||
)
|
||||
# Detokenize action tokens to continuous actions
|
||||
action_horizon = self.config.n_action_steps
|
||||
action_dim = 7
|
||||
|
||||
continuous_actions = self.detokenize_actions(
|
||||
action_tokens,
|
||||
action_horizon=action_horizon,
|
||||
action_dim=action_dim
|
||||
)
|
||||
|
||||
return continuous_actions
|
||||
|
||||
# Full mode: use flow matching with optional subtask generation
|
||||
# Use high_level_task tokens (WITHOUT subtask) for inference - we'll generate the subtask
|
||||
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
|
||||
high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"]
|
||||
@@ -1802,24 +2543,39 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
|
||||
# Get FAST action tokens from batch
|
||||
fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens)
|
||||
fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens)
|
||||
|
||||
# FAST-only mode: only use discrete action token prediction
|
||||
if self.config.fast_only:
|
||||
# Use full language tokens (no separation into high_level_task and subtask)
|
||||
tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
if fast_action_tokens is None or fast_action_masks is None:
|
||||
raise ValueError("FAST-only mode requires action.tokens and action.token_mask in the batch")
|
||||
|
||||
loss_dict = self.model.forward_fast_only(
|
||||
images, img_masks, tokens, masks,
|
||||
fast_action_tokens=fast_action_tokens,
|
||||
fast_action_masks=fast_action_masks
|
||||
)
|
||||
|
||||
loss = loss_dict["loss"]
|
||||
detailed_loss_dict = {
|
||||
"loss": loss.item(),
|
||||
"fast_loss": loss_dict["fast_loss"].item(),
|
||||
}
|
||||
return loss, detailed_loss_dict
|
||||
|
||||
# Full mode: flow matching + subtask + FAST
|
||||
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
|
||||
high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"]
|
||||
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK}"]
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Decode and print ground truth subtask tokens during training
|
||||
# if self.tokenizer is not None and self.training:
|
||||
# bsize = subtask_tokens.shape[0]
|
||||
# for i in range(bsize):
|
||||
# # Remove padding tokens (0) and special tokens
|
||||
# valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
|
||||
# # if len(valid_tokens) > 0:
|
||||
# # decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
||||
# # print(f"[Training] Ground truth subtask {i}: {decoded_text}")
|
||||
|
||||
# Get FAST action tokens from batch
|
||||
fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens)
|
||||
fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens)
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
# high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
|
||||
# subtask_tokens = subtask tokens to predict (e.g., "pick up the cup")
|
||||
|
||||
@@ -101,8 +101,8 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
                cleaned_high_level_task = cleaned_high_level_tasks[i]
                full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}"
            else:
                full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "

                full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"  #remove Action by jade

            low_level_prompts.append(full_prompt)

        transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = low_level_prompts
@@ -2,7 +2,7 @@ export CUDA_LAUNCH_BLOCKING=1
|
||||
lerobot-train \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
|
||||
--output_dir=/fsx/jade_choghari/outputs/pi0_fast_fruit1 \
|
||||
--output_dir=/fsx/jade_choghari/outputs/pi0_fast_fruit2 \
|
||||
--job_name=pi0_training \
|
||||
--policy.repo_id=jade_choghari/pi0-base1 \
|
||||
--policy.path=lerobot/pi05_base \
|
||||
@@ -14,9 +14,10 @@ lerobot-train \
|
||||
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
|
||||
}' \
|
||||
--batch_size=4 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--wandb.project=pi05hi-training \
|
||||
--policy.fast_only=true \
|
||||
# --wandb.enable=true \
|
||||
# --wandb.disable_artifact=true \
|
||||
# --wandb.project=pi05hi-training \
|
||||
# /fsx/jade_choghari/.cache/huggingface/lerobot/jadechoghari/collect-data
|
||||
@@ -1,18 +1,13 @@
|
||||
rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
|
||||
lerobot-train \
|
||||
--dataset.repo_id=local\
|
||||
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
|
||||
--dataset.root=/fsx/jade_choghari/data/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
|
||||
--job_name=pi0_multi_training \
|
||||
--policy.repo_id=jadechoghari/pi0-base1 \
|
||||
--policy.path=lerobot/pi05_base \
|
||||
--policy.path=/fsx/jade_choghari/outputs/libero_training_fast_6/checkpoints/last/pretrained_model/ \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=50000 \
|
||||
--save_freq=5000 \
|
||||
--rename_map='{
|
||||
"observation.images.base": "observation.images.base_0_rgb",
|
||||
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
|
||||
}' \
|
||||
--batch_size=32 \
|
||||
--batch_size=4 \
|
||||
--policy.device=cuda \
|
||||
|
||||
@@ -1,9 +1,12 @@
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
    --repo_id "local" \
    --root "/fsx/jade_choghari/outputs/collect-data-pgen" \
    --action_horizon 16 \
    --encoded_dims "0:15" \
    --action_horizon 50 \
    --root /fsx/jade_choghari/data/libero \
    --action_horizon 10 \
    --encoded_dims "0:7" \
    --vocab_size 1024 \
    --scale 10.0 \
    --output_dir "/fsx/jade_choghari/outputs/fast_tokenizer"
    --push_to_hub \
    --hub_repo_id jadechoghari/fast-libero-tokenizer-quantiles \
    --normalization_mode QUANTILES \


# python train_fast_tokenizer.py --repo_id my_dataset
@@ -15,6 +15,8 @@ from pathlib import Path
from transformers import AutoProcessor
import torch

from huggingface_hub import HfApi
from lerobot.configs.types import NormalizationMode
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -39,6 +41,64 @@ def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: li
|
||||
return delta_actions
|
||||
|
||||
|
||||
def apply_normalization(
|
||||
data: np.ndarray,
|
||||
stats: dict[str, np.ndarray],
|
||||
mode: NormalizationMode,
|
||||
eps: float = 1e-8,
|
||||
) -> np.ndarray:
|
||||
"""Apply normalization to data based on the specified mode.
|
||||
|
||||
Args:
|
||||
data: Data to normalize [N, H, D] or [D]
|
||||
stats: Dictionary of statistics (mean, std, min, max, q01, q99, q10, q90)
|
||||
mode: Normalization mode to apply
|
||||
eps: Small epsilon for numerical stability
|
||||
|
||||
Returns:
|
||||
Normalized data with the same shape as input
|
||||
"""
|
||||
if mode == NormalizationMode.IDENTITY:
|
||||
return data
|
||||
|
||||
if mode == NormalizationMode.MEAN_STD:
|
||||
mean = stats.get("mean")
|
||||
std = stats.get("std")
|
||||
if mean is None or std is None:
|
||||
raise ValueError("MEAN_STD mode requires 'mean' and 'std' in stats")
|
||||
return (data - mean) / np.maximum(std, eps)
|
||||
|
||||
if mode == NormalizationMode.MIN_MAX:
|
||||
min_val = stats.get("min")
|
||||
max_val = stats.get("max")
|
||||
if min_val is None or max_val is None:
|
||||
raise ValueError("MIN_MAX mode requires 'min' and 'max' in stats")
|
||||
denom = np.maximum(max_val - min_val, eps)
|
||||
return 2.0 * (data - min_val) / denom - 1.0
|
||||
|
||||
if mode == NormalizationMode.QUANTILES:
|
||||
q01 = stats.get("q01")
|
||||
q99 = stats.get("q99")
|
||||
if q01 is None or q99 is None:
|
||||
raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats")
|
||||
denom = np.maximum(q99 - q01, eps)
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q01, q99)
|
||||
return 2.0 * (clipped - q01) / denom - 1.0
|
||||
|
||||
if mode == NormalizationMode.QUANTILE10:
|
||||
q10 = stats.get("q10")
|
||||
q90 = stats.get("q90")
|
||||
if q10 is None or q90 is None:
|
||||
raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats")
|
||||
denom = np.maximum(q90 - q10, eps)
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q10, q90)
|
||||
return 2.0 * (clipped - q10) / denom - 1.0
|
||||
|
||||
raise ValueError(f"Unsupported normalization mode: {mode}")
|
||||
|
||||
|
||||
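As a worked example of the QUANTILES branch above, take q01 = -2 and q99 = 2: a value of 1.0 is inside the range and maps to 2·(1 − (−2))/4 − 1 = 0.5, while out-of-range values are clipped to ±1. A runnable check with the same arithmetic:

```python
import numpy as np

q01, q99 = np.array([-2.0]), np.array([2.0])
data = np.array([[-5.0], [1.0], [3.0]])

clipped = np.clip(data, q01, q99)
normalized = 2.0 * (clipped - q01) / np.maximum(q99 - q01, 1e-8) - 1.0
print(normalized.ravel())  # [-1.   0.5  1. ]
```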
def process_episode(args):
|
||||
"""Process single episode and return action chunks."""
|
||||
dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args
|
||||
@@ -237,9 +297,13 @@ def main(
|
||||
delta_dims: str | None = None,
|
||||
use_delta_transform: bool = False,
|
||||
state_key: str = "observation.state",
|
||||
normalization_mode: str = "QUANTILES",
|
||||
vocab_size: int = 1024,
|
||||
scale: float = 10.0,
|
||||
output_dir: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
hub_repo_id: str | None = None,
|
||||
hub_private: bool = False,
|
||||
):
|
||||
"""
|
||||
Train FAST tokenizer for action encoding.
|
||||
@@ -254,15 +318,29 @@ def main(
|
||||
delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
|
||||
use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
|
||||
state_key: Dataset key for state observations (default: "observation.state")
|
||||
normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
|
||||
vocab_size: FAST vocabulary size (BPE vocab size)
|
||||
scale: DCT scaling factor (default: 10.0)
|
||||
output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
|
||||
push_to_hub: Whether to push the tokenizer to Hugging Face Hub
|
||||
hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
|
||||
hub_private: Whether to create a private repository on the Hub
|
||||
"""
|
||||
# Load dataset
|
||||
print(f"Loading dataset: {repo_id}")
|
||||
dataset = LeRobotDataset(repo_id=repo_id, root=root)
|
||||
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
# Parse normalization mode
|
||||
try:
|
||||
norm_mode = NormalizationMode(normalization_mode)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid normalization_mode: {normalization_mode}. "
|
||||
f"Must be one of: {', '.join([m.value for m in NormalizationMode])}"
|
||||
)
|
||||
print(f"Normalization mode: {norm_mode.value}")
|
||||
|
||||
# Parse encoded dimensions
|
||||
encoded_dim_ranges = []
|
||||
for range_str in encoded_dims.split(','):
|
||||
@@ -317,13 +395,12 @@ def main(
|
||||
encoded_chunks = np.concatenate(encoded_chunks, axis=-1) # [N, H, D_encoded]
|
||||
print(f"Extracted {encoded_chunks.shape[-1]} encoded dimensions")
|
||||
|
||||
# Apply normalization to encoded dimensions only
|
||||
# NOTE: For FAST, we ALWAYS use QUANTILE normalization (no per-timestamp)
|
||||
# This clips outliers and provides consistent [-1, 1] range for DCT compression
|
||||
# Apply normalization to encoded dimensions
|
||||
print(f"\nBefore normalization - overall stats:")
|
||||
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
|
||||
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
|
||||
|
||||
# Get normalization stats from dataset
|
||||
norm_stats = dataset.meta.stats
|
||||
if norm_stats is not None and "action" in norm_stats:
|
||||
action_stats = norm_stats["action"]
|
||||
@@ -334,19 +411,31 @@ def main(
|
||||
encoded_dim_indices.extend(range(start, end))
|
||||
encoded_dim_indices = np.array(encoded_dim_indices)
|
||||
|
||||
# Use QUANTILE normalization: clip to [q01, q99] and map to [-1, 1]
|
||||
if "q01" in action_stats and "q99" in action_stats:
|
||||
q01 = np.array(action_stats["q01"])[encoded_dim_indices] # [D_encoded]
|
||||
q99 = np.array(action_stats["q99"])[encoded_dim_indices] # [D_encoded]
|
||||
# Extract stats for encoded dimensions only
|
||||
encoded_stats = {}
|
||||
for stat_name, stat_values in action_stats.items():
|
||||
if isinstance(stat_values, (list, np.ndarray)):
|
||||
stat_array = np.array(stat_values)
|
||||
if len(stat_array) > max(encoded_dim_indices):
|
||||
encoded_stats[stat_name] = stat_array[encoded_dim_indices]
|
||||
|
||||
if encoded_stats:
|
||||
print(f"\nNormalization stats for encoded dimensions (mode: {norm_mode.value}):")
|
||||
for stat_name, stat_values in encoded_stats.items():
|
||||
print(f" {stat_name}: shape={stat_values.shape}, "
|
||||
f"range=[{np.min(stat_values):.4f}, {np.max(stat_values):.4f}]")
|
||||
|
||||
print(f"\nNormalization stats (q01, q99) for encoded dimensions:")
|
||||
for i, dim_idx in enumerate(encoded_dim_indices):
|
||||
print(f" Orig dim {dim_idx}: q01={q01[i]:7.4f}, q99={q99[i]:7.4f}, range={q99[i]-q01[i]:7.4f}")
|
||||
|
||||
# Clip to quantile range and normalize to [-1, 1]
|
||||
encoded_chunks = np.clip(encoded_chunks, q01, q99)
|
||||
encoded_chunks = 2.0 * (encoded_chunks - q01) / np.maximum(q99 - q01, 1e-6) - 1.0
|
||||
print(f"\nApplied quantile normalization [q01, q99] → [-1, 1]")
|
||||
# Apply normalization based on mode
|
||||
try:
|
||||
encoded_chunks = apply_normalization(
|
||||
encoded_chunks,
|
||||
encoded_stats,
|
||||
norm_mode,
|
||||
eps=1e-8
|
||||
)
|
||||
print(f"\nApplied {norm_mode.value} normalization")
|
||||
except ValueError as e:
|
||||
print(f"Warning: {e}. Using raw actions without normalization.")
|
||||
|
||||
print(f"\nAfter normalization - overall stats:")
|
||||
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
|
||||
@@ -358,9 +447,9 @@ def main(
|
||||
print(f" Dim {d}: min={np.min(dim_data):7.4f}, max={np.max(dim_data):7.4f}, "
|
||||
f"mean={np.mean(dim_data):7.4f}, std={np.std(dim_data):7.4f}")
|
||||
else:
|
||||
print("Warning: q01/q99 stats not found, using raw actions")
|
||||
print("Warning: Could not extract stats for encoded dimensions, using raw actions")
|
||||
else:
|
||||
print("Warning: No normalization stats found, using raw actions")
|
||||
print("Warning: No normalization stats found in dataset, using raw actions")
|
||||
|
||||
print(f"Encoded chunks shape: {encoded_chunks.shape}")
|
||||
|
||||
@@ -394,6 +483,7 @@ def main(
|
||||
'delta_dim_list': delta_dim_list,
|
||||
'use_delta_transform': use_delta_transform,
|
||||
'state_key': state_key,
|
||||
'normalization_mode': norm_mode.value,
|
||||
'action_horizon': action_horizon,
|
||||
'num_training_chunks': len(encoded_chunks),
|
||||
'compression_stats': compression_stats,
|
||||
@@ -402,8 +492,41 @@ def main(
|
||||
with open(output_path / "metadata.json", 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
print(f"\n✅ Saved FAST tokenizer to {output_path}")
|
||||
print(f"\nSaved FAST tokenizer to {output_path}")
|
||||
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||
|
||||
# Push to Hugging Face Hub if requested
|
||||
if push_to_hub:
|
||||
# Determine the hub repository ID
|
||||
if hub_repo_id is None:
|
||||
hub_repo_id = output_path.name
|
||||
print(f"\nNo hub_repo_id provided, using: {hub_repo_id}")
|
||||
|
||||
print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}")
|
||||
print(f" Private: {hub_private}")
|
||||
|
||||
try:
|
||||
# Use the tokenizer's push_to_hub method
|
||||
tokenizer.push_to_hub(
|
||||
repo_id=hub_repo_id,
|
||||
private=hub_private,
|
||||
commit_message=f"Upload FAST tokenizer trained on {repo_id}"
|
||||
)
|
||||
|
||||
# Also upload the metadata.json file separately
|
||||
api = HfApi()
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path / "metadata.json"),
|
||||
path_in_repo="metadata.json",
|
||||
repo_id=hub_repo_id,
|
||||
repo_type="model",
|
||||
commit_message="Upload tokenizer metadata"
|
||||
)
|
||||
|
||||
print(f"Successfully pushed tokenizer to: https://huggingface.co/{hub_repo_id}")
|
||||
except Exception as e:
|
||||
print(f"Error pushing to hub: {e}")
|
||||
print(" Make sure you're logged in with `huggingface-cli login`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Executable
+28
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
|
||||
# FSDP training script for PI05 with aggressive memory optimization
|
||||
# Use this for large models that OOM with standard DDP
|
||||
|
||||
accelerate launch --config_file /admin/home/jade_choghari/lerobot/fsdp_config.yaml \
|
||||
$(which lerobot-train) \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/data/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fsdp \
|
||||
--job_name=libero_training_fsdp \
|
||||
--policy.repo_id=jade_choghari/pi05-fast-libero-fsdp \
|
||||
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=100000 \
|
||||
--save_freq=10 \
|
||||
--batch_size=8 \
|
||||
--policy.device=cuda \
|
||||
--policy.fast_only=true \
|
||||
--policy.scheduler_warmup_steps=2000 \
|
||||
--policy.scheduler_decay_steps=60000 \
|
||||
--policy.scheduler_decay_lr=1e-5 \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--wandb.project=pi05-libero-training-fsdp
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
export CUDA_LAUNCH_BLOCKING=1
|
||||
lerobot-train \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/data/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \
|
||||
--job_name=libero_training_fast \
|
||||
--policy.repo_id=jade_choghari/pi05-fast-libero \
|
||||
--policy.path=/fsx/jade_choghari/models/pi05-base \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=100000 \
|
||||
--save_freq=20000 \
|
||||
--batch_size=4 \
|
||||
--policy.device=cuda \
|
||||
--policy.fast_only=true \
|
||||
--policy.scheduler_warmup_steps=1000 \
|
||||
--policy.scheduler_decay_steps=30000 \
|
||||
--policy.scheduler_decay_lr=1e-5 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--rename_map='{
|
||||
"observation.images.image1": "observation.images.base_0_rgb",
|
||||
"observation.images.image2": "observation.images.left_wrist_0_rgb",
|
||||
}' \
|
||||
--policy.empty_cameras=1 \
|
||||
# /fsx/jade_choghari/.cache/huggingface/lerobot/jadechoghari/collect-data
|
||||
@@ -0,0 +1,15 @@
#!/bin/bash
#SBATCH --job-name=pi05-train
#SBATCH --time=24:00:00
#SBATCH --qos=high
#SBATCH --gres=gpu:8
#SBATCH --mem=256G
#SBATCH --partition=hopper-prod
#SBATCH --output=/fsx/jade_choghari/logs/%x-%j.out
#SBATCH --error=/fsx/jade_choghari/logs/%x-%j.err

srun \
    --container-image=/fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh \
    --container-mounts=/fsx/jade_choghari \
    --container-workdir=$HOME/lerobot \
    bash /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/train_multi.sh
@@ -1,23 +1,36 @@
rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
accelerate launch --multi_gpu --num_processes=2 \
#!/bin/bash
set -euxo pipefail

# Source YOUR Miniforge conda (mounted from FSX)
source /fsx/jade_choghari/miniforge3/etc/profile.d/conda.sh

conda activate lerobot
accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
    $(which lerobot-train) \
    --dataset.repo_id=local \
    --dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
    --output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
    --job_name=pi0_multi_training \
    --policy.repo_id=jadechoghari/pi0-base1 \
    --policy.path=lerobot/pi05_base \
    --dataset.root=/fsx/jade_choghari/data/libero \
    --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_mean_1 \
    --job_name=libero_training_fast \
    --policy.repo_id=jade_choghari/pi05-fast-libero \
    --policy.path=/fsx/jade_choghari/models/pi05-base \
    --policy.dtype=bfloat16 \
    --steps=50000 \
    --save_freq=5000 \
    --rename_map='{
        "observation.images.base": "observation.images.base_0_rgb",
        "observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
        "observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
    }' \
    --steps=100000 \
    --save_freq=20000 \
    --batch_size=4 \
    --policy.device=cuda \
    --policy.fast_only=true \
    --policy.scheduler_warmup_steps=4000 \
    --policy.scheduler_decay_steps=100000 \
    --policy.scheduler_decay_lr=1e-5 \
    --policy.gradient_checkpointing=true \
    --batch_size=1 \
    --policy.device=cpu
    # --wandb.enable=true \
    # --wandb.disable_artifact=true \
    # --wandb.project=pi05hi-training \
    --policy.chunk_size=10 \
    --policy.n_action_steps=10 \
    --policy.max_action_tokens=256 \
    --rename_map='{
        "observation.images.image1": "observation.images.base_0_rgb",
        "observation.images.image2": "observation.images.left_wrist_0_rgb",
    }' \
    --policy.empty_cameras=1 \
    --wandb.enable=true \
    --wandb.disable_artifact=true \
    --wandb.project=pi05-libero-training \

@@ -29,7 +29,9 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
from scipy.fft import idct

from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import (
@@ -223,7 +225,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
        task = self.get_task(self.transition)
        if task is None:
            raise ValueError("Task cannot be None")

        # Tokenize the task (this will create CPU tensors)
        tokenized_prompt = self._tokenize_text(task)

@@ -352,7 +353,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
            else:
                # If at max length, replace the last token with EOS
                input_ids[i, last_token_pos] = eos_token_id

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def get_config(self) -> dict[str, Any]:
@@ -453,10 +453,11 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
    tokenizer_name: str | None = None
    tokenizer: Any | None = None
    trust_remote_code: bool = True
    max_action_tokens: int = 32
    max_action_tokens: int = 256
    # Internal tokenizer instance (not part of the config)
    action_tokenizer: Any = field(default=None, init=False, repr=False)

    _paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
    _fast_skip_tokens: int = field(default=128, init=False, repr=False)
    def __post_init__(self):
        """
        Initializes the action tokenizer after the dataclass is created.
@@ -488,6 +489,9 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

        self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True, add_eos_token=True, add_bos_token=False)
        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """
@@ -520,12 +524,17 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
        return new_transition

    def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Converts action tokens to PaliGemma tokens.
        """
        return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
    def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenizes the action tensor and creates a mask.

        Args:
            action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
            action: The input action tensor to tokenize. Shape: (B, H, action_dim) or (H, action_dim,)

        Returns:
            A tuple of (tokens, mask) where:
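A note on the mapping above: `_act_tokens_to_paligemma_tokens` plants FAST action tokens at the top of the PaliGemma vocabulary, just below the reserved special tokens, and the same affine formula is its own inverse. A minimal sketch, assuming an illustrative vocab size (the real value comes from the loaded tokenizer):

# Sketch of the FAST <-> PaliGemma token mapping; vocab_size here is an assumption for illustration.
import torch

vocab_size = 257152       # assumed; use paligemma_tokenizer.vocab_size in practice
fast_skip_tokens = 128    # last 128 ids are treated as reserved/special

def act_to_paligemma(t: torch.Tensor) -> torch.Tensor:
    # FAST token 0 lands on the highest non-reserved PaliGemma id.
    return vocab_size - 1 - fast_skip_tokens - t

fast_tokens = torch.tensor([0, 1, 500])
mapped = act_to_paligemma(fast_tokens)
# Applying the same formula twice returns the original tokens.
assert torch.equal(act_to_paligemma(mapped), fast_tokens)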
@@ -568,8 +577,22 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
        if tokens.dim() > 1:
            tokens = tokens.flatten()

        bos_id = self._paligemma_tokenizer.bos_token_id
        # add bos
        tokens = torch.cat([
            torch.tensor([bos_id], device=action.device),
            torch.tensor(self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), device=action.device),
            self._act_tokens_to_paligemma_tokens(tokens),
            torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
        ])

        # Truncate or pad to max_action_tokens
        if len(tokens) > self.max_action_tokens:
            import logging
            logging.warning(
                f"Token length ({len(tokens)}) exceeds max length ({self.max_action_tokens}), truncating. "
                "Consider increasing the `max_token_len` in your model config if this happens frequently."
            )
            tokens = tokens[:self.max_action_tokens]
            mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
        else:
@@ -659,3 +682,508 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
        }

        return features


@dataclass
@ProcessorStepRegistry.register(name="action_detokenizer_processor_1")
class ActionDetokenizerProcessorStep1(ActionProcessorStep):
    """
    Processor step to detokenize action tokens back to continuous actions.

    This step takes tokenized actions (e.g., from model predictions), decodes them using
    a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
    and returns the continuous action tensor.

    This is the inverse operation of ActionTokenizerProcessorStep and is typically used
    during inference to convert predicted tokens back to executable actions.

    Requires the `transformers` library to be installed.

    Attributes:
        tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
        tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
        trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
        action_horizon: The number of timesteps for actions.
        action_dim: The dimensionality of each action.
        relaxed_decoding: Whether to use relaxed decoding for actions (allows graceful handling of partial sequences).
        action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
    """

    tokenizer_name: str | None = None
    tokenizer: Any | None = None
    trust_remote_code: bool = True
    action_horizon: int = 1
    action_dim: int = 7
    relaxed_decoding: bool = False

    # Internal tokenizer instance (not part of the config)
    action_tokenizer: Any = field(default=None, init=False, repr=False)
    _paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
    _fast_skip_tokens: int = field(default=128, init=False, repr=False)

    def __post_init__(self):
        """
        Initializes the action detokenizer after the dataclass is created.

        It checks for the availability of the `transformers` library and loads the tokenizer
        either from a provided object or by name from the Hugging Face Hub.

        Raises:
            ImportError: If the `transformers` library is not installed.
            ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
        """
        if not _transformers_available:
            raise ImportError(
                "The 'transformers' library is not installed. "
                "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
            )

        if self.tokenizer is not None:
            # Use provided tokenizer object directly
            self.action_tokenizer = self.tokenizer
        elif self.tokenizer_name is not None:
            if AutoProcessor is None:
                raise ImportError("AutoProcessor is not available")
            self.action_tokenizer = AutoProcessor.from_pretrained(
                self.tokenizer_name, trust_remote_code=self.trust_remote_code
            )
        else:
            raise ValueError(
                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

        self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
            "google/paligemma-3b-pt-224",
            trust_remote_code=True,
            add_eos_token=True,
            add_bos_token=False
        )
        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
        """
        return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens

    def decode_actions_with_fast(
        self,
        token_ids: list[int],
        time_horizon: int,
        action_dim: int,
        relaxed_decoding: bool = False
    ) -> list:
        """
        Decodes action token IDs back to continuous action values using the FAST tokenizer.

        Args:
            token_ids: List of token IDs to decode.
            time_horizon: The number of timesteps for actions.
            action_dim: The dimensionality of each action.
            relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).

        Returns:
            A list representing the decoded actions.
        """
        # Use the action tokenizer's decode method
        # The FAST tokenizer should have a decode method that converts tokens back to actions
        try:
            decoded_actions = self.action_tokenizer.decode(
                token_ids,
                time_horizon=time_horizon,
                action_dim=action_dim
            )
            return decoded_actions
        except Exception as e:
            if relaxed_decoding:
                # If relaxed decoding is enabled, try to decode as much as possible
                import logging
                logging.warning(f"Relaxed decoding: {e}. Returning partial decode.")
                try:
                    # Try to decode with whatever tokens we have
                    partial_decoded = self.action_tokenizer.decode(
                        token_ids[:len(token_ids)],
                        time_horizon=time_horizon,
                        action_dim=action_dim
                    )
                    return partial_decoded
                except Exception:
                    # Return zeros if decoding completely fails
                    return [[0.0] * action_dim for _ in range(time_horizon)]
            else:
                raise e

    def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Extracts actions from predicted output tokens using the FAST model.

        Args:
            tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)

        Returns:
            The extracted actions as a tensor of shape (B, action_horizon, action_dim) or (action_horizon, action_dim).
        """
        # Handle single sample (add batch dimension)
        single_sample = tokens.dim() == 1
        if single_sample:
            tokens = tokens.unsqueeze(0)

        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)

        # Clean the decoded tokens by removing "Action:" prefix and extracting the relevant part
        cleaned_tokens = [
            tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
            for tokens_sequence in decoded_tokens
        ]

        # Re-encode the cleaned text to get raw action tokens
        raw_action_tokens = [
            self._paligemma_tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
            for sample_tokens in cleaned_tokens
        ]

        # Convert PaliGemma tokens back to action tokens
        action_tokens = [
            self._paligemma_tokens_to_act_tokens(raw_action_token)
            for raw_action_token in raw_action_tokens
        ]

        # Decode each sample's tokens to continuous actions
        # (do not reassign `tokens` here; its device is still needed below)
        decoded_actions = [
            torch.tensor(
                self.decode_actions_with_fast(
                    tok.flatten().tolist(),
                    time_horizon=self.action_horizon,
                    action_dim=self.action_dim,
                    relaxed_decoding=self.relaxed_decoding,
                ),
                device=tokens.device,
            ).squeeze(0)
            for tok in action_tokens
        ]
        # Stack into a batch
        result = torch.stack(decoded_actions, dim=0)

        # Remove batch dimension if input was single sample
        if single_sample:
            result = result.squeeze(0)

        return result

    def action(self, action: torch.Tensor) -> torch.Tensor:
        """
        Detokenizes action tokens back to continuous actions.

        Args:
            action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)

        Returns:
            The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
        """
        return self.extract_actions(action)

    def get_config(self) -> dict[str, Any]:
        """
        Returns the serializable configuration of the processor.

        Note: The tokenizer object itself is not serialized. If the processor was initialized
        with a tokenizer name, that name will be included in the config.

        Returns:
            A dictionary with the processor's configuration parameters.
        """
        config = {
            "trust_remote_code": self.trust_remote_code,
            "action_horizon": self.action_horizon,
            "action_dim": self.action_dim,
            "relaxed_decoding": self.relaxed_decoding,
        }

        # Only save tokenizer_name if it was used to create the tokenizer
        if self.tokenizer_name is not None and self.tokenizer is None:
            config["tokenizer_name"] = self.tokenizer_name

        return config

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Updates feature definitions to reflect detokenized actions.

        This updates the policy features dictionary to indicate that the action
        has been detokenized from token IDs back to continuous values.

        Args:
            features: The dictionary of existing policy features.

        Returns:
            The updated dictionary of policy features.
        """
        # Update the action feature to reflect the continuous action shape
        if PipelineFeatureType.ACTION in features:
            # Replace the action feature with the detokenized version
            features[PipelineFeatureType.ACTION] = {
                "action": PolicyFeature(
                    type=FeatureType.STATE,  # Continuous action
                    shape=(self.action_horizon, self.action_dim)
                )
            }

        return features


@dataclass
@ProcessorStepRegistry.register(name="action_detokenizer_processor")
class ActionDetokenizerProcessorStep(ActionProcessorStep):
    """
    Processor step to detokenize action tokens back to continuous actions.

    This step takes tokenized actions (e.g., from model predictions), decodes them using
    a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
    and returns the continuous action tensor.

    This is the inverse operation of ActionTokenizerProcessorStep and is typically used
    during inference to convert predicted tokens back to executable actions.

    Requires the `transformers` library to be installed.

    Attributes:
        tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
        tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
        trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
        action_horizon: The number of timesteps for actions.
        action_dim: The dimensionality of each action.
        relaxed_decoding: Whether to use relaxed decoding for actions (allows graceful handling of partial sequences).
        action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
    """

    tokenizer_name: str | None = None
    tokenizer: Any | None = None
    trust_remote_code: bool = True
    action_horizon: int = 1
    action_dim: int = 7
    relaxed_decoding: bool = False

    # Internal tokenizer instance (not part of the config)
    action_tokenizer: Any = field(default=None, init=False, repr=False)
    _paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
    _fast_skip_tokens: int = field(default=128, init=False, repr=False)

    def __post_init__(self):
        """
        Initializes the action detokenizer after the dataclass is created.

        It checks for the availability of the `transformers` library and loads the tokenizer
        either from a provided object or by name from the Hugging Face Hub.

        Raises:
            ImportError: If the `transformers` library is not installed.
            ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
        """
        if not _transformers_available:
            raise ImportError(
                "The 'transformers' library is not installed. "
                "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
            )

        if self.tokenizer is not None:
            # Use provided tokenizer object directly
            self.action_tokenizer = self.tokenizer
        elif self.tokenizer_name is not None:
            if AutoProcessor is None:
                raise ImportError("AutoProcessor is not available")
            self.action_tokenizer = AutoProcessor.from_pretrained(
                self.tokenizer_name, trust_remote_code=self.trust_remote_code
            )
        else:
            raise ValueError(
                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

        self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
            "google/paligemma-3b-pt-224",
            trust_remote_code=True,
            add_eos_token=True,
            add_bos_token=False
        )
        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
        """
        return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens


    def decode_actions_with_fast(
        self,
        token_ids: list[int],
        time_horizon: int,
        action_dim: int,
        relaxed_decoding: bool = True
    ) -> list:
        """
        Decodes action token IDs back to continuous action values using the FAST tokenizer.

        Args:
            token_ids: List of token IDs to decode.
            time_horizon: The number of timesteps for actions.
            action_dim: The dimensionality of each action.
            relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).

        Returns:
            A list representing the decoded actions.
        """
        decoded_actions = []

        for token in token_ids:
            try:
                decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token)
                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token

                if relaxed_decoding:
                    # expected sequence length
                    expected_seq_len = time_horizon * action_dim
                    diff = expected_seq_len - decoded_dct_coeff.shape[0]

                    # apply truncation if too long
                    if diff < 0:
                        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]  # truncate on the right

                    # apply padding if too short
                    elif diff > 0:
                        decoded_dct_coeff = np.pad(
                            decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
                        )

                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
                assert decoded_dct_coeff.shape == (
                    time_horizon,
                    action_dim,
                ), (
                    f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
                )

            except Exception as e:
                print(f"Error decoding tokens: {e}")
                print(f"Tokens: {token}")
                decoded_dct_coeff = np.zeros((time_horizon, action_dim))

            decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho"))

        return np.stack(decoded_actions)

    def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Extracts actions from predicted output tokens using the FAST model.

        Args:
            tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)

        Returns:
            The extracted actions as a tensor of shape (B, action_horizon, action_dim) or (action_horizon, action_dim).
        """
        # Handle single sample (add batch dimension)
        single_sample = tokens.dim() == 1
        if single_sample:
            tokens = tokens.unsqueeze(0)

        # valid = tokens <= (self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens)
        # fast_region = tokens.masked_fill(~valid, 0)
        # fast_tokens = self._paligemma_tokens_to_act_tokens(fast_region)
        # actions = self.decode_actions_with_fast(fast_tokens.tolist(), time_horizon=self.action_horizon, action_dim=self.action_dim, relaxed_decoding=self.relaxed_decoding)[0]
        decoded_tokens = [
            self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
            for seq in tokens
        ]
        cleaned_tokens = []
        for token_seq in decoded_tokens:
            if "|" in token_seq:
                token_seq = token_seq[:token_seq.index("|")]
            cleaned_tokens.append(token_seq)
        raw_action_tokens = [
            torch.tensor(
                self._paligemma_tokenizer.convert_tokens_to_ids(token_seq),
                dtype=torch.long,
                device=tokens.device,
            )
            for token_seq in cleaned_tokens
        ]

        action_tokens = [
            self._paligemma_tokens_to_act_tokens(raw_action_token)
            for raw_action_token in raw_action_tokens
        ]
        actions = self.decode_actions_with_fast(
            action_tokens,
            time_horizon=self.action_horizon,
            action_dim=self.action_dim
        )

        return torch.tensor(actions, device=tokens.device)

    def action(self, action: torch.Tensor) -> torch.Tensor:
        """
        Detokenizes action tokens back to continuous actions.

        Args:
            action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)

        Returns:
            The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
        """
        return self.extract_actions(action)

    def get_config(self) -> dict[str, Any]:
        """
        Returns the serializable configuration of the processor.

        Note: The tokenizer object itself is not serialized. If the processor was initialized
        with a tokenizer name, that name will be included in the config.

        Returns:
            A dictionary with the processor's configuration parameters.
        """
        config = {
            "trust_remote_code": self.trust_remote_code,
            "action_horizon": self.action_horizon,
            "action_dim": self.action_dim,
            "relaxed_decoding": self.relaxed_decoding,
        }

        # Only save tokenizer_name if it was used to create the tokenizer
        if self.tokenizer_name is not None and self.tokenizer is None:
            config["tokenizer_name"] = self.tokenizer_name

        return config

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Updates feature definitions to reflect detokenized actions.

        This updates the policy features dictionary to indicate that the action
        has been detokenized from token IDs back to continuous values.

        Args:
            features: The dictionary of existing policy features.

        Returns:
            The updated dictionary of policy features.
        """
        # Update the action feature to reflect the continuous action shape
        if PipelineFeatureType.ACTION in features:
            # Replace the action feature with the detokenized version
            features[PipelineFeatureType.ACTION] = {
                "action": PolicyFeature(
                    type=FeatureType.STATE,  # Continuous action
                    shape=(self.action_horizon, self.action_dim)
                )
            }

        return features

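A hedged usage sketch (not from the commit) of how this detokenizer step could be exercised on its own, assuming the `physical-intelligence/fast` processor named in the docstring and a horizon/dim in line with the Libero runs above; with random ids the relaxed path will typically fall back to zeros, real inputs come from the policy's decoded action tokens.

# Illustrative standalone use of the detokenizer step; values are assumptions, not taken from the configs.
import torch

detok = ActionDetokenizerProcessorStep(
    tokenizer_name="physical-intelligence/fast",  # FAST processor from the Hub (per the docstring)
    action_horizon=10,
    action_dim=7,
    relaxed_decoding=True,
)

dummy_token_ids = torch.randint(0, 1000, (2, 32))  # (B, max_action_tokens) of placeholder ids
actions = detok.action(dummy_token_ids)            # expected shape: (B, action_horizon, action_dim)
print(actions.shape)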
@@ -173,6 +173,7 @@ def rollout(
        observation = env_preprocessor(observation)

        observation = preprocessor(observation)

        with torch.inference_mode():
            action = policy.select_action(observation)
        action = postprocessor(action)

@@ -62,6 +62,7 @@ def update_policy(
    accelerator: Accelerator,
    lr_scheduler=None,
    lock=None,
    postprocessor=None,
) -> tuple[MetricsTracker, dict]:
    """
    Performs a single training step to update the policy's weights.
@@ -90,6 +91,10 @@ def update_policy(
    # Let accelerator handle mixed precision
    with accelerator.autocast():
        loss, output_dict = policy.forward(batch)
        # action = policy.predict_action_chunk(batch)
        # if postprocessor is not None:
        #     action = postprocessor(action)
        # breakpoint()
        # TODO(rcadene): policy.unnormalize_outputs(out_dict)

    # Use accelerator's backward method
@@ -151,7 +156,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
        from accelerate.utils import DistributedDataParallelKwargs

        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
        accelerator = Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=4, kwargs_handlers=[ddp_kwargs])

    init_logging(accelerator=accelerator)

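A note on the `gradient_accumulation_steps=4` change above: if the training loop only steps the optimizer every 4 micro-batches, the effective optimization batch is the per-device batch times the number of processes times the accumulation factor. Illustrative arithmetic, assuming the 8-process launch and `--batch_size=4` from train_multi.sh:

# Illustrative effective-batch arithmetic for the Accelerator change above (assumed values).
per_device_batch = 4         # --batch_size=4
num_processes = 8            # accelerate launch --num_processes=8
grad_accum_steps = 4         # Accelerator(gradient_accumulation_steps=4, ...)

effective_batch = per_device_batch * num_processes * grad_accum_steps
print(effective_batch)  # 128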
@@ -206,6 +211,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
            ds_meta=dataset.meta,
            rename_map=cfg.rename_map,
        )


    # Wait for all processes to finish policy creation before continuing
    accelerator.wait_for_everyone()
@@ -244,6 +250,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
        **postprocessor_kwargs,
    )


    if is_main_process:
        logging.info("Creating optimizer and scheduler")
        optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -343,6 +350,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
                cfg.optimizer.grad_clip_norm,
                accelerator=accelerator,
                lr_scheduler=lr_scheduler,
                postprocessor=postprocessor,
            )

    # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we