mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
Fix: state is included in the language prompt, not in the flow head
This commit is contained in:
@@ -28,9 +28,6 @@ class PI05OpenPIConfig(PreTrainedConfig):
|
|||||||
# Model architecture
|
# Model architecture
|
||||||
paligemma_variant: str = "gemma_2b"
|
paligemma_variant: str = "gemma_2b"
|
||||||
action_expert_variant: str = "gemma_300m"
|
action_expert_variant: str = "gemma_300m"
|
||||||
discrete_state_input: bool | None = (
|
|
||||||
True # Whether to use discrete state input # see openpi `Pi0Config, __post_init__`
|
|
||||||
)
|
|
||||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||||
|
|
||||||
# Input / output structure
|
# Input / output structure
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@@ -581,7 +582,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
|||||||
return time.to(dtype=torch.float32, device=device)
|
return time.to(dtype=torch.float32, device=device)
|
||||||
|
|
||||||
def embed_prefix(
|
def embed_prefix(
|
||||||
self, images, img_masks, lang_tokens, lang_masks
|
self, images, img_masks, tokens, masks
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""Embed images with SigLIP and language tokens with embedding layer."""
|
"""Embed images with SigLIP and language tokens with embedding layer."""
|
||||||
embs = []
|
embs = []
|
||||||
@@ -602,14 +603,14 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
|||||||
att_masks += [0] * num_img_embs
|
att_masks += [0] * num_img_embs
|
||||||
|
|
||||||
# Process language tokens
|
# Process language tokens
|
||||||
def lang_embed_func(lang_tokens):
|
def lang_embed_func(tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
lang_emb_dim = lang_emb.shape[-1]
|
||||||
return lang_emb * math.sqrt(lang_emb_dim)
|
return lang_emb * math.sqrt(lang_emb_dim)
|
||||||
|
|
||||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||||
embs.append(lang_emb)
|
embs.append(lang_emb)
|
||||||
pad_masks.append(lang_masks)
|
pad_masks.append(masks)
|
||||||
|
|
||||||
num_lang_embs = lang_emb.shape[1]
|
num_lang_embs = lang_emb.shape[1]
|
||||||
att_masks += [0] * num_lang_embs
|
att_masks += [0] * num_lang_embs
|
||||||
@@ -623,8 +624,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks
|
return embs, pad_masks, att_masks
|
||||||
|
|
||||||
def embed_suffix(self, state, noisy_actions, timestep):
|
def embed_suffix(self, noisy_actions, timestep):
|
||||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||||
embs = []
|
embs = []
|
||||||
pad_masks = []
|
pad_masks = []
|
||||||
att_masks = []
|
att_masks = []
|
||||||
@@ -669,9 +670,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
def forward(
|
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
||||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
|
||||||
) -> Tensor:
|
|
||||||
"""Do a full training forward pass and compute the loss."""
|
"""Do a full training forward pass and compute the loss."""
|
||||||
if noise is None:
|
if noise is None:
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
noise = self.sample_noise(actions.shape, actions.device)
|
||||||
@@ -683,10 +682,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
|||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||||
images, img_masks, lang_tokens, lang_masks
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||||
)
|
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
@@ -729,15 +726,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
|||||||
return F.mse_loss(u_t, v_t, reduction="none")
|
return F.mse_loss(u_t, v_t, reduction="none")
|
||||||
|
|
||||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||||
def sample_actions(
|
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
|
||||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
|
|
||||||
) -> Tensor:
|
|
||||||
"""Do a full inference forward and compute the action."""
|
"""Do a full inference forward and compute the action."""
|
||||||
if num_steps is None:
|
if num_steps is None:
|
||||||
num_steps = self.config.num_inference_steps
|
num_steps = self.config.num_inference_steps
|
||||||
|
|
||||||
bsize = state.shape[0]
|
bsize = tokens.shape[0]
|
||||||
device = state.device
|
device = tokens.device
|
||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
# Sample noise with padded dimension as expected by action_in_proj
|
# Sample noise with padded dimension as expected by action_in_proj
|
||||||
@@ -748,9 +743,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
|||||||
) # Use config max_action_dim for internal processing
|
) # Use config max_action_dim for internal processing
|
||||||
noise = self.sample_noise(actions_shape, device)
|
noise = self.sample_noise(actions_shape, device)
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||||
images, img_masks, lang_tokens, lang_masks
|
|
||||||
)
|
|
||||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
|
||||||
@@ -773,7 +766,6 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
|||||||
while time >= -dt / 2:
|
while time >= -dt / 2:
|
||||||
expanded_time = time.expand(bsize)
|
expanded_time = time.expand(bsize)
|
||||||
v_t = self.denoise_step(
|
v_t = self.denoise_step(
|
||||||
state,
|
|
||||||
prefix_pad_masks,
|
prefix_pad_masks,
|
||||||
past_key_values,
|
past_key_values,
|
||||||
x_t,
|
x_t,
|
||||||
@@ -786,14 +778,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
|||||||
|
|
||||||
def denoise_step(
|
def denoise_step(
|
||||||
self,
|
self,
|
||||||
state,
|
|
||||||
prefix_pad_masks,
|
prefix_pad_masks,
|
||||||
past_key_values,
|
past_key_values,
|
||||||
x_t,
|
x_t,
|
||||||
timestep,
|
timestep,
|
||||||
):
|
):
|
||||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
|
||||||
|
|
||||||
suffix_len = suffix_pad_masks.shape[1]
|
suffix_len = suffix_pad_masks.shape[1]
|
||||||
batch_size = prefix_pad_masks.shape[0]
|
batch_size = prefix_pad_masks.shape[0]
|
||||||
@@ -1129,7 +1120,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
return images, img_masks
|
return images, img_masks
|
||||||
|
|
||||||
def _tokenize_language(
|
def _tokenize_language_and_state(
|
||||||
self, batch: dict[str, Tensor]
|
self, batch: dict[str, Tensor]
|
||||||
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
|
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
|
||||||
"""Tokenize language input using PaliGemma tokenizer."""
|
"""Tokenize the task prompt together with the discretized robot state using the PaliGemma tokenizer."""
|
||||||
@@ -1149,25 +1140,44 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
|||||||
batch_size = batch[next(iter(batch.keys()))].shape[0]
|
batch_size = batch[next(iter(batch.keys()))].shape[0]
|
||||||
tasks = ["Pick up the object"] * batch_size
|
tasks = ["Pick up the object"] * batch_size
|
||||||
|
|
||||||
# Tokenize with max_length padding to match OpenPI's expected format
|
# Handle discrete state input for PI05 (always the case for pi05)
|
||||||
|
# Get state from batch and discretize it
|
||||||
|
state: Any | None = batch.get(OBS_STATE)
|
||||||
|
if state is None:
|
||||||
|
raise ValueError("Robot state is required for PI05")
|
||||||
|
|
||||||
|
# Prepare state (pad to max_state_dim)
|
||||||
|
state = pad_vector(state, self.config.max_state_dim)
|
||||||
|
|
||||||
|
# State is assumed to be already normalized to [-1, 1] (by normalize_inputs); no further normalization is done here
|
||||||
|
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||||
|
state_np = state.cpu().numpy()
|
||||||
|
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||||
|
|
||||||
|
# Create full prompts with state included (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||||
|
full_prompts = []
|
||||||
|
for i, task in enumerate(tasks):
|
||||||
|
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||||
|
state_str = " ".join(map(str, discretized_states[i]))
|
||||||
|
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||||
|
full_prompts.append(full_prompt)
|
||||||
|
|
||||||
|
# Tokenize the full prompts with state
|
||||||
|
# Call the HuggingFace tokenizer directly so padding, truncation and the attention mask are produced in one call
|
||||||
tokenized = self.tokenizer(
|
tokenized = self.tokenizer(
|
||||||
tasks,
|
full_prompts,
|
||||||
padding="max_length", # Use max_length padding as per OpenPI
|
padding="max_length",
|
||||||
padding_side="right", # from lerobot pi0 `prepare_language`
|
padding_side="right",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_token_len, # Use the max token length from config
|
max_length=self.max_token_len,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
|
add_special_tokens=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
lang_tokens = tokenized["input_ids"].to(device)
|
tokens = tokenized["input_ids"].to(device)
|
||||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||||
|
|
||||||
return lang_tokens, lang_masks
|
return tokens, masks
|
||||||
|
|
||||||
def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy)
|
|
||||||
"""Pad state"""
|
|
||||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
|
||||||
return state
|
|
||||||
|
|
||||||
def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy)
|
def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy)
|
||||||
"""Pad action"""
|
"""Pad action"""
|
||||||
@@ -1196,11 +1206,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
images, img_masks = self._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
|
||||||
state = self.prepare_state(batch)
|
|
||||||
|
|
||||||
# Sample actions using the model
|
# Sample actions using the model (no separate state needed for PI05)
|
||||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
||||||
|
|
||||||
# Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action`
|
# Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action`
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
@@ -1216,13 +1225,12 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
images, img_masks = self._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
|
||||||
|
|
||||||
state = self.prepare_state(batch)
|
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss (no separate state needed for PI05)
|
||||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||||
|
|
||||||
# Truncate losses to actual action dimensions
|
# Truncate losses to actual action dimensions
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user