diff --git a/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py b/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py index 87eea97cd..8f0ca3fd7 100644 --- a/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py @@ -28,9 +28,6 @@ class PI05OpenPIConfig(PreTrainedConfig): # Model architecture paligemma_variant: str = "gemma_2b" action_expert_variant: str = "gemma_300m" - discrete_state_input: bool | None = ( - True # Whether to use discrete state input # see openpi `Pi0Config, __post_init__` - ) dtype: str = "float32" # Options: "bfloat16", "float32" # Input / output structure diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index b52d70143..3ae64dd4c 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -19,8 +19,9 @@ import logging import math from collections import deque from pathlib import Path -from typing import Literal +from typing import Any, Literal +import numpy as np import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn @@ -581,7 +582,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ return time.to(dtype=torch.float32, device=device) def embed_prefix( - self, images, img_masks, lang_tokens, lang_masks + self, images, img_masks, tokens, masks ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Embed images with SigLIP and language tokens with embedding layer.""" embs = [] @@ -602,14 +603,14 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ att_masks += [0] * num_img_embs # Process language tokens - def lang_embed_func(lang_tokens): - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + def lang_embed_func(tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) lang_emb_dim = lang_emb.shape[-1] return lang_emb * math.sqrt(lang_emb_dim) - lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) + lang_emb = self._apply_checkpoint(lang_embed_func, tokens) embs.append(lang_emb) - pad_masks.append(lang_masks) + pad_masks.append(masks) num_lang_embs = lang_emb.shape[1] att_masks += [0] * num_lang_embs @@ -623,8 +624,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ return embs, pad_masks, att_masks - def embed_suffix(self, state, noisy_actions, timestep): - """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + def embed_suffix(self, noisy_actions, timestep): + """Embed noisy_actions, timestep to prepare for Expert Gemma processing.""" embs = [] pad_masks = [] att_masks = [] @@ -669,9 +670,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ return embs, pad_masks, att_masks, adarms_cond - def forward( - self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None - ) -> Tensor: + def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor: """Do a full training forward pass and compute the loss.""" if noise is None: noise = self.sample_noise(actions.shape, actions.device) @@ -683,10 +682,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( - images, img_masks, lang_tokens, lang_masks - ) - suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) if ( self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype @@ -729,15 +726,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ return F.mse_loss(u_t, v_t, reduction="none") @torch.no_grad() # see openpi `sample_actions` (slightly adapted) - def sample_actions( - self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None - ) -> Tensor: + def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor: """Do a full inference forward and compute the action.""" if num_steps is None: num_steps = self.config.num_inference_steps - bsize = state.shape[0] - device = state.device + bsize = tokens.shape[0] + device = tokens.device if noise is None: # Sample noise with padded dimension as expected by action_in_proj @@ -748,9 +743,7 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ ) # Use config max_action_dim for internal processing noise = self.sample_noise(actions_shape, device) - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( - images, img_masks, lang_tokens, lang_masks - ) + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 @@ -773,7 +766,6 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ while time >= -dt / 2: expanded_time = time.expand(bsize) v_t = self.denoise_step( - state, prefix_pad_masks, past_key_values, x_t, @@ -786,14 +778,13 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ def denoise_step( self, - state, prefix_pad_masks, past_key_values, x_t, timestep, ): """Apply one denoising step of the noise `x_t` at a given timestep.""" - suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep) suffix_len = suffix_pad_masks.shape[1] batch_size = prefix_pad_masks.shape[0] @@ -1129,7 +1120,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): return images, img_masks - def _tokenize_language( + def _tokenize_language_and_state( self, batch: dict[str, Tensor] ) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language` """Tokenize language input using PaliGemma tokenizer.""" @@ -1149,25 +1140,44 @@ class PI05OpenPIPolicy(PreTrainedPolicy): batch_size = batch[next(iter(batch.keys()))].shape[0] tasks = ["Pick up the object"] * batch_size - # Tokenize with max_length padding to match OpenPI's expected format + # Handle discrete state input for PI05 (always the case for pi05) + # Get state from batch and discretize it + state: Any | None = batch.get(OBS_STATE) + if state is None: + raise ValueError("Robot state is required for PI05") + + # Prepare state (pad to max_state_dim) + state = pad_vector(state, self.config.max_state_dim) + + # Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs) + # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + state_np = state.cpu().numpy() + discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + # Create full prompts with state included (see openpi `PaligemmaTokenizer.tokenize()`) + full_prompts = [] + for i, task in enumerate(tasks): + cleaned_text = task.strip().replace("_", " ").replace("\n", " ") + state_str = " ".join(map(str, discretized_states[i])) + full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + full_prompts.append(full_prompt) + + # Tokenize the full prompts with state + # Use the HuggingFace tokenizer properly (not .encode() which doesn't exist on AutoTokenizer) tokenized = self.tokenizer( - tasks, - padding="max_length", # Use max_length padding as per OpenPI - padding_side="right", # from lerobot pi0 `prepare_language` + full_prompts, + padding="max_length", + padding_side="right", truncation=True, - max_length=self.max_token_len, # Use the max token length from config + max_length=self.max_token_len, return_tensors="pt", + add_special_tokens=True, ) - lang_tokens = tokenized["input_ids"].to(device) - lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool) + tokens = tokenized["input_ids"].to(device) + masks = tokenized["attention_mask"].to(device, dtype=torch.bool) - return lang_tokens, lang_masks - - def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy) - """Pad state""" - state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) - return state + return tokens, masks def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy) """Pad action""" @@ -1196,11 +1206,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy): # Prepare inputs images, img_masks = self._preprocess_images(batch) - lang_tokens, lang_masks = self._tokenize_language(batch) - state = self.prepare_state(batch) + tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05 - # Sample actions using the model - actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state) + # Sample actions using the model (no separate state needed for PI05) + actions = self.model.sample_actions(images, img_masks, tokens, masks) # Unpad actions to actual action dimension, works similar to lerobot pi0 `prepare_action` original_action_dim = self.config.output_features[ACTION].shape[0] @@ -1216,13 +1225,12 @@ class PI05OpenPIPolicy(PreTrainedPolicy): # Prepare inputs images, img_masks = self._preprocess_images(batch) - lang_tokens, lang_masks = self._tokenize_language(batch) + tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05 - state = self.prepare_state(batch) actions = self.prepare_action(batch) - # Compute loss - losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) + # Compute loss (no separate state needed for PI05) + losses = self.model.forward(images, img_masks, tokens, masks, actions) # Truncate losses to actual action dimensions original_action_dim = self.config.output_features[ACTION].shape[0]