Fix: robot state is included in the tokenized language prompt, not passed to the flow head

This commit is contained in:
Pepijn
2025-09-17 23:39:00 +02:00
parent 02f52807e6
commit 0f62c180d9
2 changed files with 56 additions and 51 deletions
@@ -28,9 +28,6 @@ class PI05OpenPIConfig(PreTrainedConfig):
# Model architecture
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
discrete_state_input: bool | None = (
True # Whether to use discrete state input # see openpi `Pi0Config, __post_init__`
)
dtype: str = "float32" # Options: "bfloat16", "float32"
# Input / output structure
@@ -19,8 +19,9 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import Literal
from typing import Any, Literal
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
@@ -581,7 +582,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
self, images, img_masks, tokens, masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer."""
embs = []
@@ -602,14 +603,14 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
att_masks += [0] * num_img_embs
# Process language tokens
def lang_embed_func(lang_tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
pad_masks.append(lang_masks)
pad_masks.append(masks)
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
@@ -623,8 +624,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
def embed_suffix(self, noisy_actions, timestep):
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
@@ -669,9 +670,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
return embs, pad_masks, att_masks, adarms_cond
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
"""Do a full training forward pass and compute the loss."""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
@@ -683,10 +682,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
@@ -729,15 +726,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
) -> Tensor:
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
num_steps = self.config.num_inference_steps
bsize = state.shape[0]
device = state.device
bsize = tokens.shape[0]
device = tokens.device
if noise is None:
# Sample noise with padded dimension as expected by action_in_proj
@@ -748,9 +743,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
@@ -773,7 +766,6 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
@@ -786,14 +778,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
def denoise_step(
self,
state,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
@@ -1129,7 +1120,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
return images, img_masks
def _tokenize_language(
def _tokenize_language_and_state(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
"""Tokenize language input using PaliGemma tokenizer."""
@@ -1149,25 +1140,44 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
batch_size = batch[next(iter(batch.keys()))].shape[0]
tasks = ["Pick up the object"] * batch_size
# Tokenize with max_length padding to match OpenPI's expected format
# Handle discrete state input for PI05 (always the case for pi05)
# Get state from batch and discretize it
state: Any | None = batch.get(OBS_STATE)
if state is None:
raise ValueError("Robot state is required for PI05")
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.config.max_state_dim)
# No normalization is done here: state is assumed to already be in the [-1, 1] range (from normalize_inputs)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Create full prompts with state included (see openpi `PaligemmaTokenizer.tokenize()`)
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
# Tokenize the full prompts with state
# Use the HuggingFace tokenizer properly (not .encode() which doesn't exist on AutoTokenizer)
tokenized = self.tokenizer(
tasks,
padding="max_length", # Use max_length padding as per OpenPI
padding_side="right", # from lerobot pi0 `prepare_language`
full_prompts,
padding="max_length",
padding_side="right",
truncation=True,
max_length=self.max_token_len, # Use the max token length from config
max_length=self.max_token_len,
return_tensors="pt",
add_special_tokens=True,
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
tokens = tokenized["input_ids"].to(device)
masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
return lang_tokens, lang_masks
def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy)
"""Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state
return tokens, masks
def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy)
"""Pad action"""
@@ -1196,11 +1206,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
lang_tokens, lang_masks = self._tokenize_language(batch)
state = self.prepare_state(batch)
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
# Sample actions using the model
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
# Sample actions using the model (no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks)
# Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action`
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1216,13 +1225,12 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
lang_tokens, lang_masks = self._tokenize_language(batch)
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
state = self.prepare_state(batch)
actions = self.prepare_action(batch)
# Compute loss
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
# Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions)
# Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0]