mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
fix, state is included in language not in flow head
This commit is contained in:
@@ -28,9 +28,6 @@ class PI05OpenPIConfig(PreTrainedConfig):
|
||||
# Model architecture
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
discrete_state_input: bool | None = (
|
||||
True # Whether to use discrete state input # see openpi `Pi0Config, __post_init__`
|
||||
)
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
# Input / output structure
|
||||
|
||||
@@ -19,8 +19,9 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
@@ -581,7 +582,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
self, images, img_masks, tokens, masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer."""
|
||||
embs = []
|
||||
@@ -602,14 +603,14 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
pad_masks.append(masks)
|
||||
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
@@ -623,8 +624,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, state, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
def embed_suffix(self, noisy_actions, timestep):
|
||||
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
@@ -669,9 +670,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
@@ -683,10 +682,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
@@ -729,15 +726,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
|
||||
) -> Tensor:
|
||||
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
bsize = tokens.shape[0]
|
||||
device = tokens.device
|
||||
|
||||
if noise is None:
|
||||
# Sample noise with padded dimension as expected by action_in_proj
|
||||
@@ -748,9 +743,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
) # Use config max_action_dim for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
@@ -773,7 +766,6 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
@@ -786,14 +778,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
@@ -1129,7 +1120,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def _tokenize_language(
|
||||
def _tokenize_language_and_state(
|
||||
self, batch: dict[str, Tensor]
|
||||
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
|
||||
"""Tokenize language input using PaliGemma tokenizer."""
|
||||
@@ -1149,25 +1140,44 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
batch_size = batch[next(iter(batch.keys()))].shape[0]
|
||||
tasks = ["Pick up the object"] * batch_size
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
# Handle discrete state input for PI05 (always the case for pi05)
|
||||
# Get state from batch and discretize it
|
||||
state: Any | None = batch.get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("Robot state is required for PI05")
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.config.max_state_dim)
|
||||
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Create full prompts with state included (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
# Tokenize the full prompts with state
|
||||
# Use the HuggingFace tokenizer properly (not .encode() which doesn't exist on AutoTokenizer)
|
||||
tokenized = self.tokenizer(
|
||||
tasks,
|
||||
padding="max_length", # Use max_length padding as per OpenPI
|
||||
padding_side="right", # from lerobot pi0 `prepare_language`
|
||||
full_prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=self.max_token_len, # Use the max token length from config
|
||||
max_length=self.max_token_len,
|
||||
return_tensors="pt",
|
||||
add_special_tokens=True,
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
tokens = tokenized["input_ids"].to(device)
|
||||
masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy)
|
||||
"""Pad state"""
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
return state
|
||||
return tokens, masks
|
||||
|
||||
def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy)
|
||||
"""Pad action"""
|
||||
@@ -1196,11 +1206,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
state = self.prepare_state(batch)
|
||||
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
|
||||
|
||||
# Sample actions using the model
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
||||
# Sample actions using the model (no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
||||
|
||||
# Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action`
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1216,13 +1225,12 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
|
||||
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
Reference in New Issue
Block a user