fix, state is included in language not in flow head

This commit is contained in:
Pepijn
2025-09-17 23:39:00 +02:00
parent 02f52807e6
commit 0f62c180d9
2 changed files with 56 additions and 51 deletions
@@ -28,9 +28,6 @@ class PI05OpenPIConfig(PreTrainedConfig):
# Model architecture # Model architecture
paligemma_variant: str = "gemma_2b" paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m" action_expert_variant: str = "gemma_300m"
discrete_state_input: bool | None = (
True # Whether to use discrete state input # see openpi `Pi0Config, __post_init__`
)
dtype: str = "float32" # Options: "bfloat16", "float32" dtype: str = "float32" # Options: "bfloat16", "float32"
# Input / output structure # Input / output structure
@@ -19,8 +19,9 @@ import logging
import math import math
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Any, Literal
import numpy as np
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn from torch import Tensor, nn
@@ -581,7 +582,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
return time.to(dtype=torch.float32, device=device) return time.to(dtype=torch.float32, device=device)
def embed_prefix( def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks self, images, img_masks, tokens, masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer.""" """Embed images with SigLIP and language tokens with embedding layer."""
embs = [] embs = []
@@ -602,14 +603,14 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
att_masks += [0] * num_img_embs att_masks += [0] * num_img_embs
# Process language tokens # Process language tokens
def lang_embed_func(lang_tokens): def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1] lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim) return lang_emb * math.sqrt(lang_emb_dim)
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb) embs.append(lang_emb)
pad_masks.append(lang_masks) pad_masks.append(masks)
num_lang_embs = lang_emb.shape[1] num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs att_masks += [0] * num_lang_embs
@@ -623,8 +624,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
return embs, pad_masks, att_masks return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep): def embed_suffix(self, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" """Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = [] embs = []
pad_masks = [] pad_masks = []
att_masks = [] att_masks = []
@@ -669,9 +670,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
return embs, pad_masks, att_masks, adarms_cond return embs, pad_masks, att_masks, adarms_cond
def forward( def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
"""Do a full training forward pass and compute the loss.""" """Do a full training forward pass and compute the loss."""
if noise is None: if noise is None:
noise = self.sample_noise(actions.shape, actions.device) noise = self.sample_noise(actions.shape, actions.device)
@@ -683,10 +682,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
x_t = time_expanded * noise + (1 - time_expanded) * actions x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
images, img_masks, lang_tokens, lang_masks suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
if ( if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
@@ -729,15 +726,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
return F.mse_loss(u_t, v_t, reduction="none") return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted) @torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions( def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
) -> Tensor:
"""Do a full inference forward and compute the action.""" """Do a full inference forward and compute the action."""
if num_steps is None: if num_steps is None:
num_steps = self.config.num_inference_steps num_steps = self.config.num_inference_steps
bsize = state.shape[0] bsize = tokens.shape[0]
device = state.device device = tokens.device
if noise is None: if noise is None:
# Sample noise with padded dimension as expected by action_in_proj # Sample noise with padded dimension as expected by action_in_proj
@@ -748,9 +743,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
) # Use config max_action_dim for internal processing ) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device) noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
@@ -773,7 +766,6 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
while time >= -dt / 2: while time >= -dt / 2:
expanded_time = time.expand(bsize) expanded_time = time.expand(bsize)
v_t = self.denoise_step( v_t = self.denoise_step(
state,
prefix_pad_masks, prefix_pad_masks,
past_key_values, past_key_values,
x_t, x_t,
@@ -786,14 +778,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
def denoise_step( def denoise_step(
self, self,
state,
prefix_pad_masks, prefix_pad_masks,
past_key_values, past_key_values,
x_t, x_t,
timestep, timestep,
): ):
"""Apply one denoising step of the noise `x_t` at a given timestep.""" """Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
suffix_len = suffix_pad_masks.shape[1] suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0] batch_size = prefix_pad_masks.shape[0]
@@ -1129,7 +1120,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
return images, img_masks return images, img_masks
def _tokenize_language( def _tokenize_language_and_state(
self, batch: dict[str, Tensor] self, batch: dict[str, Tensor]
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language` ) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
"""Tokenize language input using PaliGemma tokenizer.""" """Tokenize language input using PaliGemma tokenizer."""
@@ -1149,25 +1140,44 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
batch_size = batch[next(iter(batch.keys()))].shape[0] batch_size = batch[next(iter(batch.keys()))].shape[0]
tasks = ["Pick up the object"] * batch_size tasks = ["Pick up the object"] * batch_size
# Tokenize with max_length padding to match OpenPI's expected format # Handle discrete state input for PI05 (always the case for pi05)
# Get state from batch and discretize it
state: Any | None = batch.get(OBS_STATE)
if state is None:
raise ValueError("Robot state is required for PI05")
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.config.max_state_dim)
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Create full prompts with state included (see openpi `PaligemmaTokenizer.tokenize()`)
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
# Tokenize the full prompts with state
# Use the HuggingFace tokenizer properly (not .encode() which doesn't exist on AutoTokenizer)
tokenized = self.tokenizer( tokenized = self.tokenizer(
tasks, full_prompts,
padding="max_length", # Use max_length padding as per OpenPI padding="max_length",
padding_side="right", # from lerobot pi0 `prepare_language` padding_side="right",
truncation=True, truncation=True,
max_length=self.max_token_len, # Use the max token length from config max_length=self.max_token_len,
return_tensors="pt", return_tensors="pt",
add_special_tokens=True,
) )
lang_tokens = tokenized["input_ids"].to(device) tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool) masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
return lang_tokens, lang_masks return tokens, masks
def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy)
"""Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state
def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy) def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy)
"""Pad action""" """Pad action"""
@@ -1196,11 +1206,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
# Prepare inputs # Prepare inputs
images, img_masks = self._preprocess_images(batch) images, img_masks = self._preprocess_images(batch)
lang_tokens, lang_masks = self._tokenize_language(batch) tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
state = self.prepare_state(batch)
# Sample actions using the model # Sample actions using the model (no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state) actions = self.model.sample_actions(images, img_masks, tokens, masks)
# Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action` # Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action`
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1216,13 +1225,12 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
# Prepare inputs # Prepare inputs
images, img_masks = self._preprocess_images(batch) images, img_masks = self._preprocess_images(batch)
lang_tokens, lang_masks = self._tokenize_language(batch) tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
state = self.prepare_state(batch)
actions = self.prepare_action(batch) actions = self.prepare_action(batch)
# Compute loss # Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) losses = self.model.forward(images, img_masks, tokens, masks, actions)
# Truncate losses to actual action dimensions # Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]