diff --git a/src/lerobot/policies/pi0fast/README.md b/src/lerobot/policies/pi0fast/README.md
new file mode 100644
index 000000000..2ae69d978
--- /dev/null
+++ b/src/lerobot/policies/pi0fast/README.md
@@ -0,0 +1,49 @@
+# π₀.₅ (pi05)
+
+This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence.
+It is designed as a **Vision-Language-Action model with open-world generalization**.
+
+---
+
+## Model Overview
+
+| Feature              | π₀                                                      | π₀.₅                                      |
+| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
+| Time Conditioning    | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
+| AdaRMS               | Not used                                                | Used in action expert                     |
+| Tokenizer Length     | 48 tokens                                               | 200 tokens                                |
+| Discrete State Input | False (uses `state_proj` layer)                         | True                                      |
+| Parameter Count      | Higher (includes state embedding)                       | Lower (no state embedding)                |
+
+---
+
+## Citation
+
+If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
+
+```bibtex
+@misc{openpi2024,
+  author       = {Physical Intelligence Lab},
+  title        = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
+  year         = {2024},
+  publisher    = {GitHub},
+  howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
+  license      = {Apache-2.0}
+}
+
+@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
+  title         = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
+  author        = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
+  year          = {2025},
+  eprint        = {2504.16054},
+  archivePrefix = {arXiv},
+  primaryClass  = {cs.LG},
+  url           = {https://arxiv.org/abs/2504.16054},
+}
+```
+
+---
+
+## License
+
+This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
diff --git a/src/lerobot/policies/pi0fast/__init__.py b/src/lerobot/policies/pi0fast/__init__.py
new file mode 100644
index 000000000..4f9a9de4a
--- /dev/null
+++ b/src/lerobot/policies/pi0fast/__init__.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
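+
+# A minimal usage sketch (illustrative only; assumes `batch` was prepared by
+# the pi05 processor with image and tokenized-language features):
+#
+#   from lerobot.policies.pi0fast import PI05Config, PI05Policy
+#
+#   config = PI05Config()
+#   policy = PI05Policy(config)
+#   action = policy.select_action(batch)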
+ +from .configuration_pi05 import PI05Config +from .modeling_pi05 import PI05Policy +from .processor_pi05 import make_pi05_pre_post_processors + +__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"] diff --git a/src/lerobot/policies/pi0fast/configuration_pi05.py b/src/lerobot/policies/pi0fast/configuration_pi05.py new file mode 100644 index 000000000..7bdce70dd --- /dev/null +++ b/src/lerobot/policies/pi0fast/configuration_pi05.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.policies.rtc.configuration_rtc import RTCConfig + +DEFAULT_IMAGE_SIZE = 224 + + +@PreTrainedConfig.register_subclass("pi05") +@dataclass +class PI05Config(PreTrainedConfig): + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + dtype: str = "float32" # Options: "bfloat16", "float32" + + n_obs_steps: int = 1 + chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" + n_action_steps: int = 50 # Number of action steps to execute + + # Shorter state and action vectors will be padded to these dimensions + max_state_dim: int = 32 + max_action_dim: int = 32 + + # Flow matching parameters: see openpi `PI0Pytorch` + num_inference_steps: int = 10 + time_sampling_beta_alpha: float = 1.5 + time_sampling_beta_beta: float = 1.0 + time_sampling_scale: float = 0.999 + time_sampling_offset: float = 0.001 + min_period: float = 4e-3 + max_period: float = 4.0 + + # Real-Time Chunking (RTC) configuration + rtc_config: RTCConfig | None = None + + image_resolution: tuple[int, int] = ( + DEFAULT_IMAGE_SIZE, + DEFAULT_IMAGE_SIZE, + ) # see openpi `preprocessing_pytorch.py` + + # Add empty images. Used to add empty cameras when no image features are present. 
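+    # For example, empty_cameras=2 makes validate_features() register
+    # "observation.images.empty_camera_0" and "observation.images.empty_camera_1"
+    # as VISUAL features of shape (3, *image_resolution); at runtime these
+    # missing cameras are filled with -1 and masked out (see _preprocess_images).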
+    empty_cameras: int = 0
+
+    tokenizer_max_length: int = 200  # see openpi `__post_init__`
+
+    normalization_mapping: dict[str, NormalizationMode] = field(
+        default_factory=lambda: {
+            "VISUAL": NormalizationMode.IDENTITY,
+            "STATE": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for state
+            "ACTION": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for action
+        }
+    )
+
+    # Training settings
+    gradient_checkpointing: bool = False  # Enable gradient checkpointing for memory optimization
+    compile_model: bool = False  # Whether to use torch.compile for model optimization
+    compile_mode: str = "max-autotune"  # Torch compile mode
+    device: str | None = None  # Device to use for the model (None = auto-detect)
+
+    # Optimizer settings: see openpi `AdamW`
+    optimizer_lr: float = 2.5e-5  # see openpi `CosineDecaySchedule: peak_lr`
+    optimizer_betas: tuple[float, float] = (0.9, 0.95)
+    optimizer_eps: float = 1e-8
+    optimizer_weight_decay: float = 0.01
+    optimizer_grad_clip_norm: float = 1.0
+
+    # Scheduler settings: see openpi `CosineDecaySchedule`
+    # Note: These will auto-scale if --steps < scheduler_decay_steps
+    # For example, --steps=3000 will scale warmup to 100 and decay to 3000
+    scheduler_warmup_steps: int = 1_000
+    scheduler_decay_steps: int = 30_000
+    scheduler_decay_lr: float = 2.5e-6
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # Validate configuration
+        if self.n_action_steps > self.chunk_size:
+            raise ValueError(
+                f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
+            )
+
+        if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
+            raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
+
+        if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
+            raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
+
+        if self.dtype not in ["bfloat16", "float32"]:
+            raise ValueError(f"Invalid dtype: {self.dtype}")
+
+    def validate_features(self) -> None:
+        """Validate and set up input/output features."""
+        for i in range(self.empty_cameras):
+            key = f"observation.images.empty_camera_{i}"
+            empty_camera = PolicyFeature(
+                type=FeatureType.VISUAL,
+                shape=(3, *self.image_resolution),  # Use configured image resolution
+            )
+            self.input_features[key] = empty_camera
+
+        if "observation.state" not in self.input_features:
+            state_feature = PolicyFeature(
+                type=FeatureType.STATE,
+                shape=(self.max_state_dim,),  # Padded to max_state_dim
+            )
+            self.input_features["observation.state"] = state_feature
+
+        if "action" not in self.output_features:
+            action_feature = PolicyFeature(
+                type=FeatureType.ACTION,
+                shape=(self.max_action_dim,),  # Padded to max_action_dim
+            )
+            self.output_features["action"] = action_feature
+
+    def get_optimizer_preset(self) -> AdamWConfig:
+        return AdamWConfig(
+            lr=self.optimizer_lr,
+            betas=self.optimizer_betas,
+            eps=self.optimizer_eps,
+            weight_decay=self.optimizer_weight_decay,
+            grad_clip_norm=self.optimizer_grad_clip_norm,
+        )
+
+    def get_scheduler_preset(self):
+        return CosineDecayWithWarmupSchedulerConfig(
+            peak_lr=self.optimizer_lr,
+            decay_lr=self.scheduler_decay_lr,
+            num_warmup_steps=self.scheduler_warmup_steps,
+            num_decay_steps=self.scheduler_decay_steps,
+        )
+
+    @property
+    def observation_delta_indices(self) -> None:
+        return None
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(self.chunk_size))
+
+    @property
+    def reward_delta_indices(self) -> None:
return None diff --git a/src/lerobot/policies/pi0fast/modeling_pi05.py b/src/lerobot/policies/pi0fast/modeling_pi05.py new file mode 100644 index 000000000..72ac03a3e --- /dev/null +++ b/src/lerobot/policies/pi0fast/modeling_pi05.py @@ -0,0 +1,1100 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import logging +import math +from collections import deque +from pathlib import Path +from typing import TYPE_CHECKING, Literal, TypedDict + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from typing_extensions import Unpack + +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +else: + CONFIG_MAPPING = None + modeling_gemma = None + GemmaForCausalLM = None + PaliGemmaForConditionalGeneration = None + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.utils.constants import ( + ACTION, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OPENPI_ATTENTION_MASK_VALUE, +) + + +class ActionSelectKwargs(TypedDict, total=False): + inference_delay: int | None + prev_chunk_left_over: Tensor | None + execution_horizon: int | None + + +def get_safe_dtype(target_dtype, device_type): + """Get a safe dtype for the given device type.""" + if device_type == "mps" and target_dtype == torch.float64: + return torch.float32 + if device_type == "cpu": + # CPU doesn't support bfloat16, use float32 instead + if target_dtype == torch.bfloat16: + return torch.float32 + if target_dtype == torch.float64: + return torch.float64 + return target_dtype + + +def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy) + time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + + dtype = get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + # Compute the outer product + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = scaling_factor[None, :] * 
time[:, None] + return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + + +def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) + alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) + beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + dist = torch.distributions.Beta(alpha_t, beta_t) + return dist.sample((bsize,)) + + +def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy) + """Copied from big_vision. + + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + return att_2d_masks & pad_2d_masks + + +def pad_vector(vector, new_dim): + """Pad the last dimension of a vector to new_dim with zeros. + + Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] >= new_dim: + return vector + return F.pad(vector, (0, new_dim - vector.shape[-1])) + + +def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) 
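+
+        For illustration (shape arithmetic, not from the openpi docs): a float32
+        batch of shape [1, 3, 480, 640] resized to height=224, width=224 is first
+        scaled by ratio = max(640/224, 480/224) to [1, 3, 168, 224], then padded
+        with -1.0 (28 rows top and bottom) to reach [1, 3, 224, 224].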
+ + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode="constant", + value=constant_value, + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + return padded_images + + +# Define the complete layer computation function for gradient checkpointing +def compute_layer_complete( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert +): + models = [paligemma.language_model, gemma_expert.model] + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + gates.append(gate) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + batch_size = query_states.shape[0] + scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + # Attention computation + att_output, _ = 
modeling_gemma.eager_attention_forward(
+        paligemma.language_model.layers[layer_idx].self_attn,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        scaling,
+    )
+    # Get head_dim from the current layer, not from the model
+    head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
+    # 8 is num_heads for both supported gemma variants (see get_gemma_config)
+    att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
+    # Process layer outputs
+    outputs_embeds = []
+    start_pos = 0
+    for i, hidden_states in enumerate(inputs_embeds):
+        layer = models[i].layers[layer_idx]
+        end_pos = start_pos + hidden_states.shape[1]
+        if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
+            att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
+        out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
+        # first residual
+        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
+        after_first_residual = out_emb.clone()
+        out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
+        # Convert to bfloat16 if the next layer (mlp) uses bfloat16
+        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
+            out_emb = out_emb.to(dtype=torch.bfloat16)
+        out_emb = layer.mlp(out_emb)
+        # second residual
+        out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
+        outputs_embeds.append(out_emb)
+        start_pos = end_pos
+    return outputs_embeds
+
+
+class GemmaConfig:  # see openpi `gemma.py: Config`
+    """Configuration for Gemma model variants."""
+
+    def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
+        self.width = width
+        self.depth = depth
+        self.mlp_dim = mlp_dim
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = head_dim
+
+
+def get_gemma_config(variant: str) -> GemmaConfig:  # see openpi `gemma.py: get_config`
+    """Returns config for specified gemma variant."""
+    if variant == "gemma_300m":
+        return GemmaConfig(
+            width=1024,
+            depth=18,
+            mlp_dim=4096,
+            num_heads=8,
+            num_kv_heads=1,
+            head_dim=256,
+        )
+    elif variant == "gemma_2b":
+        return GemmaConfig(
+            width=2048,
+            depth=18,
+            mlp_dim=16_384,
+            num_heads=8,
+            num_kv_heads=1,
+            head_dim=256,
+        )
+    else:
+        raise ValueError(f"Unknown variant: {variant}")
+
+
+# see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel`; this class is an
+# almost exact copy of the openpi implementation
+class PaliGemmaWithExpertModel(nn.Module):
+    """PaliGemma model with action expert for PI05."""
+
+    def __init__(
+        self,
+        vlm_config,
+        action_expert_config,
+        use_adarms=None,
+        precision: Literal["bfloat16", "float32"] = "bfloat16",
+        image_size: int = DEFAULT_IMAGE_SIZE,
+    ):
+        if use_adarms is None:
+            use_adarms = [False, False]
+        super().__init__()
+
+        vlm_config_hf = CONFIG_MAPPING["paligemma"]()
+        vlm_config_hf._vocab_size = 257152  # noqa: SLF001
+        vlm_config_hf.image_token_index = 257152
+        vlm_config_hf.text_config.hidden_size = vlm_config.width
+        vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
+        vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
+        vlm_config_hf.text_config.head_dim = vlm_config.head_dim
+        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
+        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
+        vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
+        vlm_config_hf.text_config.torch_dtype = "float32"
+        vlm_config_hf.text_config.vocab_size = 257152
+        vlm_config_hf.text_config.use_adarms = use_adarms[0]
+        vlm_config_hf.text_config.adarms_cond_dim = 
vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.image_size = image_size + vlm_config_hf.vision_config.intermediate_size = 4304 + vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" + vlm_config_hf.vision_config.torch_dtype = "float32" + + action_expert_config_hf = CONFIG_MAPPING["gemma"]( + head_dim=action_expert_config.head_dim, + hidden_size=action_expert_config.width, + intermediate_size=action_expert_config.mlp_dim, + num_attention_heads=action_expert_config.num_heads, + num_hidden_layers=action_expert_config.depth, + num_key_value_heads=action_expert_config.num_kv_heads, + vocab_size=257152, + hidden_activation="gelu_pytorch_tanh", + torch_dtype="float32", + use_adarms=use_adarms[1], + adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, + ) + + self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) + self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) + self.gemma_expert.model.embed_tokens = None + + self.to_bfloat16_for_selected_params(precision) + + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") + + params_to_keep_float32 = [ + "vision_tower.vision_model.embeddings.patch_embedding.weight", + "vision_tower.vision_model.embeddings.patch_embedding.bias", + "vision_tower.vision_model.embeddings.position_embedding.weight", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] + + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) + + def embed_image(self, image: torch.Tensor): + return self.paligemma.model.get_image_features(image) + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.language_model.embed_tokens(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ): + if adarms_cond is None: + adarms_cond = [None, None] + if inputs_embeds[1] is None: + prefix_output = self.paligemma.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[0] if adarms_cond is not None else None, + ) + prefix_past_key_values = prefix_output.past_key_values + prefix_output = prefix_output.last_hidden_state + suffix_output = None + elif inputs_embeds[0] is None: + suffix_output = self.gemma_expert.model.forward( + inputs_embeds=inputs_embeds[1], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[1] if adarms_cond is not None else None, + ) + suffix_output = suffix_output.last_hidden_state + prefix_output = None + prefix_past_key_values = None + else: + models = [self.paligemma.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + + # Check if gradient checkpointing is enabled for any of the models + 
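+            # (Checkpointing trades compute for memory: activations of each
+            # checkpointed layer are recomputed during the backward pass instead
+            # of being stored, costing roughly one extra forward per layer.)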
use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, "gradient_checkpointing") + and self.gemma_expert.model.gradient_checkpointing + and self.training + ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) + + # Process all layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + else: + inputs_embeds = compute_layer_complete( + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + + # final norm + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None + + return [prefix_output, suffix_output], prefix_past_key_values + + +class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` + """Core PI05 PyTorch model.""" + + def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None): + super().__init__() + self.config = config + self.rtc_processor = rtc_processor + + paligemma_config = get_gemma_config(config.paligemma_variant) + action_expert_config = get_gemma_config(config.action_expert_variant) + + if config.image_resolution[0] != config.image_resolution[1]: + raise ValueError( + f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}" + ) + + self.paligemma_with_expert = PaliGemmaWithExpertModel( + paligemma_config, + action_expert_config, + use_adarms=[False, True], + precision=config.dtype, + image_size=config.image_resolution[0], + ) + + self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim) + + self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) + self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + # Also compile the main forward pass used during training + self.forward = torch.compile(self.forward, mode=config.compile_mode) + + msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + + def 
gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing for PI05Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + logging.info("Disabled gradient checkpointing for PI05Pytorch model") + + def _rtc_enabled(self): + return self.config.rtc_config is not None and self.config.rtc_config.enabled + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + + def shift_padding_side( + self, + tokens: torch.Tensor, + ar_mask: torch.Tensor, + padding_mask: torch.Tensor, + loss_mask: torch.Tensor, + targets: torch.Tensor, + token_type_ids: torch.Tensor, + padding_side: str = "right", + ) -> tuple[torch.Tensor]: + if padding_side not in ["right", "left"]: + return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids + + new_tokens = torch.empty_like(tokens) + new_ar_masks = torch.empty_like(ar_mask) + new_padding_mask = torch.empty_like(padding_mask) + new_loss_mask = torch.empty_like(loss_mask) + new_targets = torch.empty_like(targets) + new_token_type_ids = torch.empty_like(token_type_ids) + batch_size = tokens.shape[0] + for i in range(batch_size): + padding_indices = torch.where(padding_mask[i] == 0)[0] + non_padding_indices = torch.where(padding_mask[i] == 1)[0] + if padding_side == "left": + new_indices = torch.cat((padding_indices, non_padding_indices), dim=0) + else: + new_indices = torch.cat((non_padding_indices, padding_indices), dim=0) + new_tokens[i] = tokens[i].index_select(0, new_indices) + new_ar_masks[i] = ar_mask[i].index_select(0, new_indices) + new_padding_mask[i] = padding_mask[i].index_select(0, new_indices) + new_loss_mask[i] = loss_mask[i].index_select(0, new_indices) + new_targets[i] = targets[i].index_select(0, new_indices) + new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices) + + return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids + + def embed_prefix( + self, images, img_masks, tokens, attention_mask, padded_mask + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = 
self._apply_checkpoint(image_embed_func, img)
+            bsize, num_img_embs = img_emb.shape[:2]
+
+            embs.append(img_emb)
+            pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
+            att_masks += [0] * num_img_embs
+
+        # Process language tokens
+        def lang_embed_func(tokens):
+            lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
+            lang_emb_dim = lang_emb.shape[-1]
+            return lang_emb * math.sqrt(lang_emb_dim)
+
+        lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
+        embs.append(lang_emb)
+        # Cast so the padding mask concatenates with the bool image masks
+        pad_masks.append(padded_mask.to(dtype=torch.bool))
+
+        embs = torch.cat(embs, dim=1)
+        pad_masks = torch.cat(pad_masks, dim=1)
+        bsize = pad_masks.shape[0]
+
+        # Image tokens form a single bidirectional prefix block (all zeros); the
+        # autoregressive structure over the language/action tokens is supplied
+        # by the per-sample `attention_mask`
+        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
+        att_masks = att_masks[None, :].expand(bsize, -1)
+        att_masks = torch.cat([att_masks, attention_mask.to(dtype=torch.bool)], dim=1)
+
+        return embs, pad_masks, att_masks
+
+    def forward(self, images, img_masks, tokens, masks) -> Tensor:
+        """Do a full training forward pass and compute the per-sample CE loss.
+
+        `tokens` contains the tokenized prompt together with the tokenized
+        actions; `masks` is the corresponding padding mask.
+        """
+        # The padding mask doubles as the autoregressive mask here, so valid
+        # tokens attend causally; a dedicated AR mask from the tokenizer would
+        # be more faithful to openpi.
+        embs, pad_masks, att_masks = self.embed_prefix(images, img_masks, tokens, masks, masks)
+
+        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
+        position_ids = torch.cumsum(pad_masks, dim=1) - 1
+
+        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
+        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001
+
+        outputs = self.paligemma_with_expert.paligemma.forward(
+            input_ids=None,
+            token_type_ids=None,
+            attention_mask=att_2d_masks_4d,
+            position_ids=position_ids,
+            past_key_values=None,
+            inputs_embeds=embs,
+            use_cache=False,
+            labels=None,
+        )
+
+        loss_fct = nn.CrossEntropyLoss(reduction="none")
+        device = embs.device
+
+        # Keep only the logits over the language/action tokens (the image
+        # embeddings occupy the leading positions of the sequence)
+        logits = outputs.logits[:, -tokens.shape[1] :, :]
+
+        # Shift left for next-token prediction; padding positions are excluded
+        # from the loss. A complete port would restrict the loss to the action
+        # tokens via a dedicated loss mask from the tokenizer.
+        logits = logits[:, :-1, :]
+        targets = tokens[:, 1:].to(device)
+        loss_mask = masks[:, 1:].to(device)
+
+        # Compute per-token loss
+        token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
+
+        # Apply loss mask
+        token_loss = token_loss.reshape(targets.shape[0], -1) * loss_mask
+
+        # Per-sample loss: mean over the valid tokens of each sequence
+        loss = token_loss.sum(dim=1) / torch.clamp(loss_mask.sum(dim=1), min=1)
+
+        return loss
+
+
+class PI05Policy(PreTrainedPolicy):
+    """PI05 Policy for LeRobot."""
+
+    config_class = PI05Config
+    name = "pi05"
+
+    def __init__(
+        self,
+        config: PI05Config,
+        **kwargs,
+    ):
+        """
+        Args:
+            config: Policy configuration class instance.
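+
+        Example (illustrative sketch; `batch` must already contain the image
+        and tokenized-language features expected by the processor):
+            config = PI05Config(chunk_size=50, n_action_steps=50)
+            policy = PI05Policy(config)
+            action = policy.select_action(batch)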
+ """ + super().__init__(config) + config.validate_features() + self.config = config + + # Initialize the core PI05 model + self.init_rtc_processor() + self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor) + + # Enable gradient checkpointing if requested + if config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + self.model.to(config.device) + + self.reset() + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Override the from_pretrained method to handle key remapping and display important disclaimer.""" + print( + "The PI05 model is a direct port of the OpenPI implementation. \n" + "This implementation follows the original OpenPI structure for compatibility. \n" + "Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") + + # Use provided config if available, otherwise create default config + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + model = cls(config, **kwargs) + + # Now manually load and remap the state dict + try: + # Try to load the pytorch_model.bin or model.safetensors file + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Try safetensors first + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + original_state_dict = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + + # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + + # Then add "model." 
prefix for all keys that don't already have it + remapped_state_dict = {} + remap_count = 0 + + for key, value in fixed_state_dict.items(): + if not key.startswith("model."): + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + if remap_count <= 10: # Only print first 10 to avoid spam + print(f"Remapped: {key} -> {new_key}") + else: + remapped_state_dict[key] = value + + if remap_count > 0: + print(f"Remapped {remap_count} state dict keys") + + # Load the remapped state dict into the model + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) + + if missing_keys: + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All keys loaded successfully!") + + except Exception as e: + print(f"Warning: Could not remap state dict keys: {e}") + + return model + + def _fix_pytorch_state_dict_keys( + self, state_dict, model_config + ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys` + """Fix state dict keys to match current model architecture.""" + import re + + fixed_state_dict = {} + + for key, value in state_dict.items(): + new_key = key + + # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias + # For gemma expert layers + if re.match( + r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", + key, + ): + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue + + if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue + + # Handle MLP naming changes for pi05 + # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_* + if key.startswith("action_time_mlp_in."): + new_key = key.replace("action_time_mlp_in.", "time_mlp_in.") + elif key.startswith("action_time_mlp_out."): + new_key = key.replace("action_time_mlp_out.", "time_mlp_out.") + # Also handle state_proj which shouldn't exist in pi05 + if key.startswith("state_proj."): + logging.warning(f"Skipping state_proj key in pi05 mode: {key}") + continue + + # Handle vision tower embedding layer potential differences + if "patch_embedding" in key: + # Some checkpoints might have this, but current model expects different structure + logging.warning(f"Vision embedding key might need handling: {key}") + + fixed_state_dict[new_key] = value + + return fixed_state_dict + + def get_optim_params(self) -> dict: + return self.parameters() + + def reset(self): + """Reset internal state - called when 
environment resets."""
+        self._action_queue = deque(maxlen=self.config.n_action_steps)
+        self._queues = {
+            ACTION: deque(maxlen=self.config.n_action_steps),
+        }
+
+    def init_rtc_processor(self):
+        """Initialize RTC processor if RTC is enabled in config."""
+        self.rtc_processor = None
+
+        # Create processor if config provided
+        # If RTC is not enabled - we can still track the denoising data
+        if self.config.rtc_config is not None:
+            self.rtc_processor = RTCProcessor(self.config.rtc_config)
+
+        model_value = getattr(self, "model", None)
+        if model_value is not None:
+            model_value.rtc_processor = self.rtc_processor
+
+    def _rtc_enabled(self) -> bool:
+        return self.config.rtc_config is not None and self.config.rtc_config.enabled
+
+    def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
+        """Preprocess images for the model.
+
+        Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
+        PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
+        """
+        images = []
+        img_masks = []
+
+        # Get device from model parameters
+        device = next(self.parameters()).device
+
+        present_img_keys = [key for key in self.config.image_features if key in batch]
+        missing_img_keys = [key for key in self.config.image_features if key not in batch]
+
+        if len(present_img_keys) == 0:
+            raise ValueError(
+                f"All image features are missing from the batch. At least one expected. "
+                f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
+            )
+
+        # Preprocess image features present in the batch
+        for key in present_img_keys:
+            img = batch[key]
+
+            # Ensure tensor is on the same device as the model
+            if img.device != device:
+                img = img.to(device)
+
+            # Ensure float32 dtype for consistency
+            if img.dtype != torch.float32:
+                img = img.to(torch.float32)
+
+            # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
+            is_channels_first = img.shape[1] == 3  # Check if channels are in dimension 1
+
+            if is_channels_first:
+                # Convert [B, C, H, W] to [B, H, W, C] for processing
+                img = img.permute(0, 2, 3, 1)
+
+            # Normalize from [0, 1] to [-1, 1] as expected by SigLIP *before*
+            # resizing, since resize_with_pad_torch clamps float32 inputs to
+            # [-1, 1] and pads them with -1.0
+            img = img * 2.0 - 1.0
+
+            # from openpi preprocess_observation_pytorch: Resize with padding if needed
+            if img.shape[1:3] != self.config.image_resolution:
+                img = resize_with_pad_torch(img, *self.config.image_resolution)
+
+            # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
+            if is_channels_first:
+                img = img.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
+
+            images.append(img)
+            # Create mask (all ones for real images)
+            bsize = img.shape[0]
+            mask = torch.ones(bsize, dtype=torch.bool, device=device)
+            img_masks.append(mask)
+
+        # Create image features not present in the batch as padded images filled
+        # with -1 (SigLIP's padding value), with an all-zero mask
+        for _num_empty_cameras in range(len(missing_img_keys)):
+            img = torch.ones_like(img) * -1  # Padded with -1 for SigLIP
+            mask = torch.zeros_like(mask)  # Mask is zero for empty cameras
+            images.append(img)
+            img_masks.append(mask)
+
+        return images, img_masks
+
+    def prepare_action(self, batch):
+        """Pad action"""
+        actions = pad_vector(batch[ACTION], self.config.max_action_dim)
+        return actions
+
+    @torch.no_grad()
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+        """Select a single action given environment observations."""
+        assert not self._rtc_enabled(), (
+            "RTC is not supported for select_action, use it with predict_action_chunk"
+        )
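+        # The model predicts a chunk of up to `chunk_size` actions but only
+        # `n_action_steps` of them are executed; the queue below is refilled
+        # once every `n_action_steps` environment steps, amortizing the
+        # expensive model call across the chunk.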
+        self.eval()
+
+        # Action queue logic for n_action_steps > 1
+        if len(self._action_queue) == 0:
+            actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
+            # Transpose to get shape (n_action_steps, batch_size, action_dim)
+            self._action_queue.extend(actions.transpose(0, 1))
+
+        return self._action_queue.popleft()
+
+    @torch.no_grad()
+    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        self.eval()
+
+        # Prepare inputs
+        images, img_masks = self._preprocess_images(batch)
+        tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
+
+        # Sample and decode action tokens with the model (RTC kwargs are passed
+        # through; no separate state input is needed for PI05). This assumes
+        # the model exposes `sample_actions`, the same hook that is wrapped by
+        # torch.compile in `PI05Pytorch.__init__`.
+        actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
+
+        # Unpad actions to the actual action dimension
+        original_action_dim = self.config.output_features[ACTION].shape[0]
+        actions = actions[:, :, :original_action_dim]
+
+        return actions
+
+    def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
+        """Run the batch through the model and compute the loss for training.
+
+        Args:
+            batch: Training batch containing observations and actions.
+            reduction: How to reduce the loss. Options:
+                - "mean": Return scalar mean loss (default, backward compatible)
+                - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+        """
+        # Prepare inputs
+        images, img_masks = self._preprocess_images(batch)
+        tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
+
+        # Compute the per-sample cross-entropy loss. The processor folds the
+        # (padded) actions into the language token sequence, so no separate
+        # state or action tensors are passed to the model for PI05.
+        per_sample_loss = self.model.forward(images, img_masks, tokens, masks)
+
+        loss_dict = {}
+        if reduction == "none":
+            # Per-sample losses of shape (batch_size,) for RA-BC weighting
+            loss_dict["loss"] = per_sample_loss.mean().item()
+            return per_sample_loss, loss_dict
+        else:
+            # Default: return scalar mean loss
+            loss = per_sample_loss.mean()
+            loss_dict["loss"] = loss.item()
+            return loss, loss_dict
diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py
new file mode 100644
index 000000000..33de3c0df
--- /dev/null
+++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py
@@ -0,0 +1,995 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ONLY AN EXAMPLE FILE, NEVER USED, IT IS OLD CODE
+"""
+π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
+
+[Paper](https://huggingface.co/papers/2501.09747)
+[Jax code](https://github.com/Physical-Intelligence/openpi)
+
+Designed by Physical Intelligence. Ported from Jax by Hugging Face.
+Disclaimer: It is not expected to perform as well as the original implementation.
+
+Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
+```bash
+lerobot-train \
+--policy.path=lerobot/pi0fast_base \
+--dataset.repo_id=danaaubakirova/koch_test
+```
+
+Example of training the pi0+FAST neural network from scratch:
+```bash
+lerobot-train \
+--policy.type=pi0fast \
+--dataset.repo_id=danaaubakirova/koch_test
+```
+
+Example of using the pi0+FAST pretrained model outside the LeRobot training framework:
+```python
+policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
+```
+
+"""
+
+from collections import deque
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn.functional as F  # noqa: N812
+from PIL import Image
+from scipy.fft import idct
+from torch import Tensor, nn
+from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
+from transformers.cache_utils import HybridCache, StaticCache
+from transformers.models.auto import CONFIG_MAPPING
+
+from lerobot.constants import ACTION, OBS_STATE
+from lerobot.policies.normalize import Normalize, Unnormalize
+from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
+from lerobot.policies.pretrained import PreTrainedPolicy
+
+PRECISION = {
+    "float16": torch.float16,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+
+def normalize(x, min_val, max_val):
+    return (x - min_val) / (max_val - min_val)
+
+
+def unnormalize(x, min_val, max_val):
+    return x * (max_val - min_val) + min_val
+
+
+def safe_arcsin(value):
+    # This ensures that the input stays within
+    # [-1, 1] to avoid invalid values for arcsin
+    return torch.arcsin(torch.clamp(value, -1.0, 1.0))
+
+
+def aloha_gripper_to_angular(value):
+    # Aloha transforms the gripper positions into a linear space. The following code
+    # reverses this transformation to be consistent with pi0 which is pretrained in
+    # angular space.
+    #
+    # These values are coming from the Aloha code:
+    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
+    value = unnormalize(value, min_val=0.01844, max_val=0.05800)
+
+    # This is the inverse of the angular to linear transformation inside the Interbotix code.
+    def linear_to_radian(linear_position, arm_length, horn_radius):
+        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
+        return safe_arcsin(value)
+
+    # The constants are taken from the Interbotix code.
+    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
+
+    # Normalize to [0, 1].
+    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
+    return normalize(value, min_val=0.4, max_val=1.5)
+
+
+def aloha_gripper_from_angular(value):
+    # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
+    # Note that the units are still angular but the range is different.
+
+    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
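+    # For illustration: a pi0 gripper value of 0.5 maps to
+    # 0.4 + 0.5 * (1.5 - 0.4) = 0.95 below, which the final normalize() call
+    # rescales into Aloha's joint range [-0.6213, 1.4910], giving ~0.744.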
+ value = unnormalize(value, min_val=0.4, max_val=1.5) + + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE + return normalize(value, min_val=-0.6213, max_val=1.4910) + + +def aloha_gripper_from_angular_inv(value): + # Directly inverts the gripper_from_angular function. + value = unnormalize(value, min_val=-0.6213, max_val=1.4910) + return normalize(value, min_val=0.4, max_val=1.5) + + +class PI0FASTPolicy(PreTrainedPolicy): + """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot.""" + + config_class = PI0FASTConfig + name = "pi0fast" + + def __init__( + self, + config: PI0FASTConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + + super().__init__(config) + config.validate_features() + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") + self.model = PI0FAST(config) + + self.reset() + + def reset(self): + """This should be called whenever the environment is reset.""" + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + """Override the from_pretrained method to display important disclaimer.""" + print( + "⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n" + " It is not expected to perform as well as the original implementation. \n" + " Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + return super().from_pretrained(*args, **kwargs) + + def get_optim_params(self) -> dict: + return self.parameters() + + def _pi_aloha_decode_state(self, state): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + state[:, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) + return state + + def _pi_aloha_encode_actions(self, actions): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) + return actions + + def _pi_aloha_encode_actions_inv(self, actions): + # Flip the joints again. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. 
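+        # Motors 6 and 13 are the gripper joints of the two arms in the
+        # 14-dimensional bimanual Aloha action layout.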
+ for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) + return actions + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + raise NotImplementedError("Currently not implemented for PI0FAST") + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + + if self.config.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + + batch = self.normalize_inputs(batch) + + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._action_queue) == 0: + actions = self.model.generate_actions(batch) + + actions = actions[:, : self.config.n_action_steps] + + original_action_dim = self.config.action_feature.shape[ + 0 + ] # self.config.max_action_dim # self.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.adapt_to_pi_aloha: + actions = self._pi_aloha_encode_actions(actions) + + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. + self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + if self.config.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + loss_dict = self.model.forward(batch) + return loss_dict["loss"], loss_dict + + +def block_causal_update_causal_mask( + attention_mask, + token_type_ids=None, + past_key_values=None, + cache_position=None, + input_tensor=None, + attn_implementation: str = "eager", + dtype: torch.dtype = "float32", +): + """ + Update the causal mask during training and generation. It can be customized to different attention masks. 
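+
+    For illustration: when `token_type_ids` is provided, the cumulative-sum
+    comparison below (cumsum[:, None, :] <= cumsum[:, :, None]) lets each token
+    attend to all previous blocks and to every token within its own block,
+    the same block-causal construction used by big_vision-style prefix-LM masks.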
+ """ + if attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(dtype).min + + if input_tensor is None: + input_tensor = attention_mask + + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + + if using_static_cache or isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + # Handle precomputed attention masks + if attention_mask is not None and attention_mask.dim() == 4: + return attention_mask + + # Causal mask initialization + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + + # Standard causal masking (triu ensures tokens can only attend to past) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + # Apply block causal mask + if token_type_ids is not None: + token_type_ids = token_type_ids.to(causal_mask.device).bool() + cumsum = torch.cumsum(token_type_ids, dim=1) + block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None] + + # Combine causal_mask with block-wise attention mask + causal_mask = torch.where(block_causal_mask, 0.0, causal_mask) + causal_mask = causal_mask[:, None, :, :] + else: + # Apply past cache position constraint + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + else: + # Apply past cache position constraint + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + + if attention_mask is not None: + causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits + mask_length = attention_mask.shape[-1] + + # Apply padding mask + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +def prepare_inputs_for_generation( + # self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + attention_mask=None, + token_type_ids=None, + use_cache=True, + num_logits_to_keep=None, + labels=None, + self=None, + **kwargs, +): + # create block causal attention + if cache_position[0] > 0 and input_ids.shape[1] > 0: + input_tensor = input_ids[:, -1:] + new_positions = ( + torch.ones( + (position_ids.shape[0], input_ids.shape[1]), + dtype=position_ids.dtype, + device=position_ids.device, + ).cumsum(-1) + + position_ids[:, -1:] + ) + position_ids = torch.cat([position_ids, new_positions], dim=-1) + else: + input_tensor = inputs_embeds + attention_mask = block_causal_update_causal_mask( + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + input_tensor=input_tensor, + token_type_ids=token_type_ids, + dtype=self.dtype, + attn_implementation=self.config.text_config._attn_implementation, + ) + # Overwritten -- custom `position_ids` and `pixel_values` 
handling
+    model_inputs = self.language_model.prepare_inputs_for_generation(
+        input_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        cache_position=cache_position,
+        use_cache=use_cache,
+        num_logits_to_keep=num_logits_to_keep,
+        token_type_ids=token_type_ids,
+        **kwargs,
+    )
+
+    # Position ids in Paligemma are 1-indexed
+    if model_inputs.get("position_ids") is not None:
+        model_inputs["position_ids"] += 1
+    # In the cached decoding stage, pixel values should be None because the input ids no longer
+    # contain the special image token; otherwise pixel values must be passed to the model.
+    # NOTE: use_cache=False always needs pixel_values.
+    if cache_position[0] == 0:
+        model_inputs["pixel_values"] = pixel_values
+    is_training = token_type_ids is not None and labels is not None
+    if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+        input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
+        causal_mask = self._update_causal_mask(
+            attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
+        )
+        model_inputs["attention_mask"] = causal_mask
+
+    return model_inputs
+
+
+class PI0FAST(nn.Module):
+    def __init__(self, config: PI0FASTConfig):
+        super().__init__()
+        self.config = config
+
+        # TODO: move the tokenizers into the Policy
+        fast_tokenizer_path = "physical-intelligence/fast"
+        pi0_paligemma_path = "google/paligemma-3b-pt-224"
+        self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
+        self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
+        self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
+        self.fast_skip_tokens = self.config.fast_skip_tokens
+        self.max_input_seq_len = self.config.max_input_seq_len
+        self.action_horizon = self.config.chunk_size
+        self.action_dim = self.config.action_feature.shape[0]
+        precision = config.precision
+        torch_precision = PRECISION.get(precision, torch.float32)
+        self.pad_token_id = (
+            self.paligemma_tokenizer.pad_token_id
+            if hasattr(self.paligemma_tokenizer, "pad_token_id")
+            else self.paligemma_tokenizer.eos_token_id
+        )
+
+        paligemma_config = CONFIG_MAPPING["paligemma"](
+            transformers_version="4.48.1",
+            _vocab_size=257152,
+            bos_token_id=2,
+            eos_token_id=1,
+            hidden_size=2048,
+            image_token_index=257152,
+            model_type="paligemma",
+            pad_token_id=0,
+            projection_dim=2048,
+            text_config={
+                "hidden_activation": "gelu_pytorch_tanh",
+                "hidden_size": 2048,
+                "intermediate_size": 16384,
+                "model_type": "gemma",
+                "num_attention_heads": 8,
+                "num_hidden_layers": 18,
+                "num_image_tokens": 256,
+                "num_key_value_heads": 1,
+                "torch_dtype": precision,
+                "vocab_size": 257152,
+                "_attn_implementation": "eager",
+            },
+            vision_config={
+                "hidden_size": 1152,
+                "intermediate_size": 4304,
+                "model_type": "siglip_vision_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27,
+                "num_image_tokens": 256,
+                "patch_size": 14,
+                "projection_dim": 2048,
+                "projector_hidden_act": "gelu_pytorch_tanh",
+                "torch_dtype": precision,
+                "vision_use_head": False,
+            },
+        )
+        self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
+
+        self.pi0_paligemma.prepare_inputs_for_generation = partial(
+            prepare_inputs_for_generation, self=self.pi0_paligemma
+        )
+        # Cast the language model, vision tower, and multi-modal projector to the requested precision
+        params_to_change_dtype = [
+            "language_model",
+            "vision_tower",
+            "multi_modal",
+        ]
+        for name, param in self.pi0_paligemma.named_parameters():
+            if any(selector in name for selector in params_to_change_dtype):
+                param.data = param.data.to(dtype=torch_precision)
+        self.set_requires_grad()
+        self.image_keys = self.config.image_features.keys()
+        # TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed
+        # AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index'
+        self.ignore_index = self.pi0_paligemma.config.ignore_index
+        self.padding_side = self.config.padding_side
+
+    def set_requires_grad(self):
+        if self.config.freeze_vision_encoder:
+            self.pi0_paligemma.vision_tower.eval()
+            for params in self.pi0_paligemma.vision_tower.parameters():
+                params.requires_grad = False
+        # To avoid an unused-params issue with distributed training
+        if self.config.freeze_lm_head:
+            for name, params in self.pi0_paligemma.named_parameters():
+                if "embed_tokens" in name:  # lm head and embedding layer are tied
+                    params.requires_grad = False
+
+    def embed_tokens(self, tokens: torch.Tensor):
+        return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
+
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
+
+    def prepare_images(self, batch):
+        """Preprocess LeRobot batch into Pi0 inputs"""
+        images = []
+        img_masks = []
+        present_img_keys = [key for key in self.image_keys if key in batch]
+        if len(present_img_keys) == 0:
+            raise ValueError(
+                f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
+            )
+
+        # Preprocess image features present in the batch
+        num_empty_cameras = 0
+        for key in self.image_keys:
+            if key in present_img_keys:
+                img = batch[key]
+
+                if self.config.resize_imgs_with_padding is not None:
+                    img = resize_with_pad(
+                        img,
+                        *self.config.resize_imgs_with_padding,
+                        pad_value=0,
+                        interpolate_like_pi=self.config.interpolate_like_pi,
+                    )
+
+                # Normalize from range [0,1] to [-1,1] as expected by siglip
+                img = img * 2.0 - 1.0
+
+                bsize = img.shape[0]
+                device = img.device
+                mask = torch.ones(bsize, dtype=torch.bool, device=device)
+            else:
+                if num_empty_cameras >= self.config.empty_cameras:
+                    continue
+                # NOTE: reuses `img` from the previous iteration for its shape and device, so at
+                # least one real camera must come before the empty ones in `image_keys`.
+                img = torch.ones_like(img) * -1
+                bsize = img.shape[0]
+                device = img.device
+                mask = torch.ones(bsize, dtype=torch.bool, device=device)
+                num_empty_cameras += 1
+
+            images.append(img)
+            img_masks.append(mask)
+        return images, img_masks
+
+    def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
+        mins = actions.amin(dim=(1, 2), keepdim=True)
+        maxs = actions.amax(dim=(1, 2), keepdim=True)
+        return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
+
+    def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+        return self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
+
+    def fast_tokenizer_wrapper(self, actions_norm):
+        """
+        A wrapper for self.fast_tokenizer that ensures batch processing, conversion to PyTorch
+        tensors, and padding of the returned token ids to the longest sequence in the batch.
+ """ + batch_tokens = self.fast_tokenizer(actions_norm) + fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt") + + return fast_out + + def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor: + token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool) + # Compute cumulative sum mask + cumsum_mask = (padded_mask != 0).cumsum(dim=1) + # Suffix block (everything after prefix_len) + suffix_mask = cumsum_mask > prefix_len + token_type_ids = suffix_mask + return token_type_ids + + def create_input_tokens(self, state, lang_text, actions=None): + bsize = state.shape[0] + device = state.device + bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1] + discretized = torch.bucketize(state, bins) - 1 + discretized = discretized[:, :32] + + prefix_texts = [] + state_text = [] + for txt, disc in zip(lang_text, discretized, strict=False): + cleaned = txt.lower().strip().replace("_", " ") + state_str = " ".join(str(val.item()) for val in disc) + prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n") + state_text.append(f"State: {state_str};\n") + + prefix_out = self.paligemma_tokenizer( + prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False + ) + prefix_ids = prefix_out["input_ids"].to(device) + prefix_mask = prefix_out["attention_mask"].to(device) + prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu() + + if actions is not None: + actions_norm = self.normalize_actions(actions) + actions_pad = F.pad( + actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0 + )[:, :, : self.config.max_action_dim] + fast_out = self.fast_tokenizer_wrapper( + actions_pad.cpu(), + ) + act_ids = fast_out["input_ids"] + act_mask = fast_out["attention_mask"].to(device) + + act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device) + # Replace action with 0 to pad tokens + act_ids = torch.where( + act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens, + self.pad_token_id, + act_ids, + ) + + eos_token = torch.tensor( + [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device + ).expand(bsize, -1) + eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1) + bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt") + bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device) + bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device) + act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1) + act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1) + act_mask = act_mask.to(device) + else: + act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device) + act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device) + final_ids = torch.cat([prefix_ids, act_ids], dim=1) + + final_mask = torch.cat([prefix_mask, act_mask], dim=1) + batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()} + + # Use tokenizer pad function + padded_output = self.paligemma_tokenizer.pad( + batch_inputs, padding="longest", max_length=180, return_tensors="pt" + ) + padded_mask = padded_output["attention_mask"] + + # define tensor of padding lengths + att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens + + token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens) + + padded_output["padded_mask"] = padded_output.pop("attention_mask") + padded_output["attention_mask"] = att_mask 
+ # loss is computed not on prefix, and not on padding + padded_output["loss_mask"] = att_mask & padded_output["padded_mask"] + padded_output["token_type_ids"] = token_type_ids + return padded_output + + def shift_padding_side( + self, + tokens: torch.Tensor, + ar_mask: torch.Tensor, + padding_mask: torch.Tensor, + loss_mask: torch.Tensor, + targets: torch.Tensor, + token_type_ids: torch.Tensor, + padding_side: str = "right", + ) -> tuple[torch.Tensor]: + if padding_side not in ["right", "left"]: + return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids + + new_tokens = torch.empty_like(tokens) + new_ar_masks = torch.empty_like(ar_mask) + new_padding_mask = torch.empty_like(padding_mask) + new_loss_mask = torch.empty_like(loss_mask) + new_targets = torch.empty_like(targets) + new_token_type_ids = torch.empty_like(token_type_ids) + batch_size = tokens.shape[0] + for i in range(batch_size): + padding_indices = torch.where(padding_mask[i] == 0)[0] + non_padding_indices = torch.where(padding_mask[i] == 1)[0] + if padding_side == "left": + new_indices = torch.cat((padding_indices, non_padding_indices), dim=0) + else: + new_indices = torch.cat((non_padding_indices, padding_indices), dim=0) + new_tokens[i] = tokens[i].index_select(0, new_indices) + new_ar_masks[i] = ar_mask[i].index_select(0, new_indices) + new_padding_mask[i] = padding_mask[i].index_select(0, new_indices) + new_loss_mask[i] = loss_mask[i].index_select(0, new_indices) + new_targets[i] = targets[i].index_select(0, new_indices) + new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices) + + return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids + + def forward(self, batch: dict[str, Tensor]): + device = batch[OBS_STATE].device + # TODO: keep like this or move to the policy .forward + images, img_masks = self.prepare_images(batch) + + padded_outs = self.create_input_tokens( + state=batch[OBS_STATE], + lang_text=batch["task"], + actions=batch[ACTION], + ) + + embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs( + images, + img_masks, + padded_outs["input_ids"], + padded_outs["padded_mask"], + padded_outs["attention_mask"], + padded_outs["loss_mask"], + padded_outs["token_type_ids"], + padding_side=self.padding_side, + ) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + token_type_ids = token_type_ids.to(dtype=torch.int64) + past_seen_tokens = 0 + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device) + pad_masks = block_causal_update_causal_mask( + attention_mask=pad_masks, + past_key_values=None, + cache_position=cache_position, + input_tensor=embs, + token_type_ids=token_type_ids, + dtype=self.pi0_paligemma.dtype, + attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation, + ) + outputs = self.pi0_paligemma.forward( + input_ids=None, + token_type_ids=None, + attention_mask=pad_masks, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=embs, + use_cache=False, + labels=None, + ) + + logits = outputs.logits + + loss_fct = nn.CrossEntropyLoss(reduction="none") + + # Shift left for next-step prediction + logits = logits[:, :-1, :] + targets = targets[:, 1:].to(device) # Shift targets + loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape + + # Compute per-token loss + token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1)) + + # Apply loss mask + token_loss = token_loss * loss_mask.reshape(-1) + + # Compute final loss 
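+        # (masked mean over action tokens; the clamp avoids division by zero in the
+        # degenerate case where a batch contains no action tokens)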
+        loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
+
+        # Return loss dictionary
+        loss_dict = {"ce_loss": loss.item(), "loss": loss}
+        return loss_dict
+
+    def decode_actions_with_fast(
+        self,
+        tokens: list[list[int]],
+        *,
+        time_horizon: int | None = None,
+        action_dim: int | None = None,
+        relaxed_decoding: bool = True,
+    ) -> np.ndarray:
+        """
+        Adapted from the original FAST decoding. With `relaxed_decoding`, malformed token
+        sequences are padded or truncated to the expected length so that actions are still
+        returned; zeros are used only when decoding raises an exception.
+        """
+        self.time_horizon = (
+            time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
+        )
+        self.action_dim = (
+            action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
+        )
+
+        # Cache the time horizon and action dimension for the next call
+        self.called_time_horizon = self.time_horizon
+        self.called_action_dim = self.action_dim
+
+        assert self.time_horizon is not None and self.action_dim is not None, (
+            "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
+        )
+
+        decoded_actions = []
+        for token in tokens:
+            try:
+                decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
+                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
+                if relaxed_decoding:
+                    # Expected sequence length
+                    expected_seq_len = self.time_horizon * self.action_dim
+                    diff = expected_seq_len - decoded_dct_coeff.shape[0]
+                    # Truncate on the right if too long
+                    if diff < 0:
+                        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]
+                    # Pad with zeros if too short
+                    elif diff > 0:
+                        decoded_dct_coeff = np.pad(
+                            decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
+                        )
+
+                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
+                assert decoded_dct_coeff.shape == (
+                    self.time_horizon,
+                    self.action_dim,
+                ), (
+                    f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
+                )
+            except Exception as e:
+                print(f"Error decoding tokens: {e}")
+                print(f"Tokens: {token}")
+                decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
+            decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
+        return np.stack(decoded_actions)
+
+    def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
+        """
+        Extracts actions from predicted output tokens using the FAST model.
+
+        Args:
+            tokens (torch.Tensor): The input tensor of tokenized outputs.
+            action_horizon (int): The number of timesteps for actions.
+            action_dim (int): The dimensionality of each action.
+
+        Returns:
+            torch.Tensor: The extracted actions as a tensor of shape (batch_size, action_horizon, action_dim).
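+
+        The steps below: decode the generated ids to text, keep the segment after "Action: ",
+        re-encode it, map it from the PaliGemma vocabulary back to FAST tokens, and convert
+        it to continuous actions via an inverse DCT in `decode_actions_with_fast`.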
+ """ + # Decode predicted output tokens + decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True) + cleaned_tokens = [ + tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip() + for tokens_sequence in decoded_tokens + ] + raw_action_tokens = [ + self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False) + for sample_tokens in cleaned_tokens + ] # something like this should be robust #looks good + action_tokens = [ + self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens + ] + # returns the tensor of decoded actions per sample in a list + decoded_actions = [ + torch.tensor( + self.decode_actions_with_fast( + tok.tolist(), + time_horizon=action_horizon, + action_dim=action_dim, + relaxed_decoding=self.config.relaxed_action_decoding, + ), + device=tokens.device, + ).squeeze(0) + for tok in action_tokens + ] + + return torch.stack( + decoded_actions, + dim=0, + ) + + def generate_actions(self, batch: dict[str, Tensor]): + # TODO: keep like this or move to the policy .forward + images, img_masks = self.prepare_images(batch) + + padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None) + embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs( + images, + img_masks, + padded_outs["input_ids"], + padded_outs["padded_mask"], + padded_outs["attention_mask"], + padded_outs["loss_mask"], + padded_outs["token_type_ids"], + padding_side="left", + ) + token_type_ids = token_type_ids.to(dtype=torch.int64) + prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1 + output_tokens = self.pi0_paligemma.generate( + input_ids=None, + attention_mask=pad_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=embs, + use_cache=self.config.use_cache, + max_new_tokens=self.config.max_decoding_steps, + do_sample=False, + num_beams=1, + token_type_ids=token_type_ids, + ) + actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim) + return actions + + def embed_image(self, image: torch.Tensor): + # Handle different transformers versions + if hasattr(self.pi0_paligemma, "get_image_features"): + return self.pi0_paligemma.get_image_features(image) + else: + return self.pi0_paligemma.model.get_image_features(image) + + def embed_inputs( + self, + images, + img_masks, + tokens, + pad_mask, + ar_mask, + loss_mask, + token_type_ids, + padding_side: str = "right", + ): + # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty + # images are a list of same size + # vectorizing everything! 
+        device = images[0].device
+        all_images = torch.stack(images, dim=1).to(device)
+        b, n, c, h, w = all_images.shape
+        all_images = all_images.view(b * n, c, h, w)
+        embedded = self.embed_image(all_images).to(device)
+        b_n, p, image_embedding_dim = embedded.shape  # Extract current dimensions
+        m = b_n // b  # Compute the number of images per sample dynamically
+
+        # Reshape dynamically
+        embedded = embedded.view(b, m, p, image_embedding_dim)
+        tokens_embs = self.embed_tokens(tokens.to(device))
+
+        img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
+        num_img_emb = embedded.shape[2]
+        img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
+        img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
+
+        image_target_tokens = (
+            torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
+        ).reshape(b, -1)
+        image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
+
+        embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim)  # Shape: (B, N*P, D)
+
+        embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
+        pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
+        att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
+        loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
+        targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
+        token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
+
+        # Shift pad tokens to the left (.generate()) or right (.train())
+        embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
+            embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
+        )
+
+        targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
+        return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
+
+
+def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
+    # Effectively a no-op when the image already matches the target width and height
+    if img.ndim != 4:
+        raise ValueError(f"Expected shape (b, c, h, w), got {img.shape}")
+
+    cur_height, cur_width = img.shape[2:]
+
+    ratio = max(cur_width / width, cur_height / height)
+    resized_height = int(cur_height / ratio)
+    resized_width = int(cur_width / ratio)
+
+    if interpolate_like_pi:
+        img = (img * 255.0).to(dtype=torch.uint8)
+        img = img.permute(0, 2, 3, 1)
+        original_device = img.device
+        img = img.to(device="cpu").numpy()
+        imgs = []
+        for sub_img in img:
+            sub_img = Image.fromarray(sub_img)
+            resized_img = sub_img.resize((resized_width, resized_height), resample=2)  # PIL bilinear
+            resized_img = torch.from_numpy(np.array(resized_img))
+            imgs.append(resized_img)
+        img = torch.stack(imgs, dim=0)
+        img = img.permute(0, 3, 1, 2)
+        resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
+    else:
+        resized_img = F.interpolate(
+            img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+        )
+
+    pad_height = max(0, int(height - resized_height))
+    pad_width = max(0, int(width - resized_width))
+
+    # pad on left and top of image
+    padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
+    return padded_img
\ No newline at end of file
diff --git a/src/lerobot/policies/pi0fast/processor_pi05.py b/src/lerobot/policies/pi0fast/processor_pi05.py
new file mode 100644
index 000000000..e29bc4c23
--- /dev/null
+++ 
b/src/lerobot/policies/pi0fast/processor_pi05.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.policies.pi05.configuration_pi05 import PI05Config +from lerobot.policies.pi05.modeling_pi05 import pad_vector +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.utils.constants import ( + OBS_STATE, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) + + +@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step") +@dataclass +class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): + """ + Processor step to prepare the state and tokenize the language input. + """ + + max_state_dim: int = 32 + task_key: str = "task" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + transition = transition.copy() + + state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE) + if state is None: + raise ValueError("State is required for PI05") + tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key) + if tasks is None: + raise ValueError("No task found in complementary data") + + # TODO: check if this necessary + state = deepcopy(state) + + # Prepare state (pad to max_state_dim) + state = pad_vector(state, self.max_state_dim) + + # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step + # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + state_np = state.cpu().numpy() + discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + full_prompts = [] + for i, task in enumerate(tasks): + cleaned_text = task.strip().replace("_", " ").replace("\n", " ") + state_str = " ".join(map(str, discretized_states[i])) + full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + full_prompts.append(full_prompt) + + transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts + # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!) 
+ # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + This step does not alter the feature definitions. + """ + return features + + +def make_pi05_pre_post_processors( + config: PI05Config, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the PI0 policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Appending a newline character to the task description for tokenizer compatibility. + 5. Tokenizing the text prompt using the PaliGemma tokenizer. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0 policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + # Add remaining processors + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + # NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep + # because the tokenizer step expects normalized state in [-1, 1] range for discretization + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), + TokenizerProcessorStep( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps: list[ProcessorStep] = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 2ef89c107..25276f9c3 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -27,13 +27,14 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import torch +import torch.nn.functional as F from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import 
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.utils.import_utils import _transformers_available from .core import EnvTransition, TransitionKey -from .pipeline import ObservationProcessorStep, ProcessorStepRegistry +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry, ProcessorStep # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: @@ -268,3 +269,328 @@ class TokenizerProcessorStep(ObservationProcessorStep): ) return features + + +@dataclass +@ProcessorStepRegistry.register(name="pi0fast_tokenizer_processor") +class PI0FASTTokenizerProcessorStep(ProcessorStep): + """ + Processor step to tokenize state, language, and actions for PI0FAST models. + + This step handles the complete tokenization pipeline for PI0FAST: + 1. Discretizes state observations + 2. Formats task descriptions with state + 3. Tokenizes actions using the FAST tokenizer + 4. Combines everything into the proper format with masks + + Example usage: + ```python + from transformers import AutoTokenizer, AutoProcessor + from lerobot.processor.tokenizer_processor import PI0FASTTokenizerProcessorStep + + # Initialize tokenizers + paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + paligemma_processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") + fast_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True) + + # Create processor step + processor = PI0FASTTokenizerProcessorStep( + paligemma_tokenizer=paligemma_tokenizer, + fast_tokenizer=fast_tokenizer, + paligemma_processor=paligemma_processor, + max_action_dim=7, + fast_skip_tokens=2, + max_input_seq_len=180, + task_key="task", + state_key="observation.state" + ) + + # Apply to a transition + tokenized_transition = processor(transition) + + # Access tokenized data from observation + input_ids = tokenized_transition["observation"]["pi0fast_input_ids"] + attention_mask = tokenized_transition["observation"]["pi0fast_attention_mask"] + loss_mask = tokenized_transition["observation"]["pi0fast_loss_mask"] + token_type_ids = tokenized_transition["observation"]["pi0fast_token_type_ids"] + ``` + + Attributes: + paligemma_tokenizer: The PaliGemma tokenizer for text + fast_tokenizer: The FAST tokenizer for actions + paligemma_processor: The PaliGemma processor + max_action_dim: Maximum dimension for actions (default: 7) + fast_skip_tokens: Number of tokens to skip in FAST tokenizer mapping (default: 2) + max_input_seq_len: Maximum input sequence length (default: 180) + padding_side: The side to pad on ('left' or 'right', default: 'right') + task_key: The key in complementary_data where the task string is stored (default: 'task') + state_key: The key in observation where the state is stored (default: 'observation.state') + """ + + paligemma_tokenizer: Any = None + fast_tokenizer: Any = None + paligemma_processor: Any = None + max_action_dim: int = 7 + fast_skip_tokens: int = 2 + max_input_seq_len: int = 180 + padding_side: str = "right" + task_key: str = "task" + state_key: str = OBS_STATE + + def __post_init__(self): + """Initialize the tokenizers.""" + if not _transformers_available: + raise ImportError( + "The 'transformers' library is not installed. " + "Please install it with `pip install 'lerobot[transformers-dep]'` to use PI0FASTTokenizerProcessorStep." 
+ ) + + if self.paligemma_tokenizer is None or self.fast_tokenizer is None or self.paligemma_processor is None: + raise ValueError( + "paligemma_tokenizer, fast_tokenizer, and paligemma_processor must all be provided. " + "These should be initialized tokenizer/processor objects." + ) + + def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor: + """Normalize actions to [-1, 1] range per batch element.""" + mins = actions.amin(dim=(1, 2), keepdim=True) + maxs = actions.amax(dim=(1, 2), keepdim=True) + return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1 + + def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + """Convert FAST tokens to PaliGemma vocabulary space.""" + vocab_size = getattr(self.paligemma_tokenizer, "vocab_size", 257152) + return vocab_size - 1 - self.fast_skip_tokens - tokens + + def fast_tokenizer_wrapper(self, actions_norm): + """Wrapper for FAST tokenizer that ensures batch processing and returns PyTorch tensors.""" + batch_tokens = self.fast_tokenizer(actions_norm) + fast_out = self.paligemma_processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt") + return fast_out + + def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: torch.Tensor) -> torch.Tensor: + """Create token type IDs to distinguish prefix from action tokens.""" + token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool) + cumsum_mask = (padded_mask != 0).cumsum(dim=1) + suffix_mask = cumsum_mask > prefix_len + return suffix_mask + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Process the transition and add tokenized inputs. + + Args: + transition: The environment transition to process + + Returns: + The transition with added tokenized data + """ + self.transition = transition + + # Extract components from transition + observation = transition.get(TransitionKey.OBSERVATION) + action = transition.get(TransitionKey.ACTION) + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + + if observation is None: + raise ValueError("Observation is None in transition") + + # Get state and language + state = observation.get(self.state_key) + if state is None: + raise ValueError(f"State key '{self.state_key}' not found in observation") + + # Get task description + if complementary_data is None: + raise ValueError("Complementary data is None, cannot extract task") + + task_data = complementary_data.get(self.task_key) + if task_data is None: + raise ValueError(f"Task key '{self.task_key}' not found in complementary data") + + # Standardize task to list of strings + if isinstance(task_data, str): + lang_text = [task_data] + elif isinstance(task_data, list) and all(isinstance(t, str) for t in task_data): + lang_text = task_data + else: + raise ValueError(f"Task must be string or list of strings, got {type(task_data)}") + + # Create tokenized inputs + tokenized_data = self.create_input_tokens(state, lang_text, action) + + # Add tokenized data to observation + new_observation = dict(observation) + new_observation["pi0fast_input_ids"] = tokenized_data["input_ids"] + new_observation["pi0fast_attention_mask"] = tokenized_data["attention_mask"] + new_observation["pi0fast_padded_mask"] = tokenized_data["padded_mask"] + new_observation["pi0fast_loss_mask"] = tokenized_data["loss_mask"] + new_observation["pi0fast_token_type_ids"] = tokenized_data["token_type_ids"] + + # Create new transition with updated observation + new_transition = dict(transition) + new_transition[TransitionKey.OBSERVATION] = 
new_observation + + return new_transition + + def create_input_tokens(self, state, lang_text, actions=None): + """ + Create tokenized input from state, language, and actions. + + This method follows the same logic as the original PI0FAST create_input_tokens method. + + Args: + state: State tensor [batch_size, state_dim] + lang_text: List of task description strings + actions: Optional action tensor [batch_size, horizon, action_dim] + + Returns: + Dictionary containing input_ids, attention_mask, padded_mask, loss_mask, and token_type_ids + """ + bsize = state.shape[0] + device = state.device + + # Discretize state + bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1] + discretized = torch.bucketize(state, bins) - 1 + discretized = discretized[:, :32] + + # Create prefix texts with task and state + prefix_texts = [] + for txt, disc in zip(lang_text, discretized, strict=False): + cleaned = txt.lower().strip().replace("_", " ") + state_str = " ".join(str(val.item()) for val in disc) + prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n") + + # Tokenize prefix + prefix_out = self.paligemma_tokenizer( + prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False + ) + prefix_ids = prefix_out["input_ids"].to(device) + prefix_mask = prefix_out["attention_mask"].to(device) + prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu() + + # Get pad token ID + pad_token_id = ( + self.paligemma_tokenizer.pad_token_id + if hasattr(self.paligemma_tokenizer, "pad_token_id") + else self.paligemma_tokenizer.eos_token_id + ) + + if actions is not None: + # pad actions + actions_pad = F.pad( + actions, (0, max(0, self.max_action_dim - actions.shape[2])), value=0 + )[:, :, : self.max_action_dim] + + # Tokenize actions with FAST tokenizer + fast_out = self.fast_tokenizer_wrapper(actions_pad.cpu()) + act_ids = fast_out["input_ids"] + act_mask = fast_out["attention_mask"].to(device) + + # Convert FAST tokens to PaliGemma token space + act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device) + + # Replace padding tokens + vocab_size = getattr(self.paligemma_tokenizer, "vocab_size", 257152) + act_ids = torch.where( + act_ids == vocab_size - 1 - self.fast_skip_tokens, + pad_token_id, + act_ids, + ) + + # Add BOS ("Action: ") and EOS tokens + eos_token = torch.tensor( + [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device + ).expand(bsize, -1) + eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1) + + bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt") + bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device) + bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device) + + act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1) + act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1) + act_mask = act_mask.to(device) + else: + # No actions provided + act_ids = torch.empty(bsize, 0, dtype=torch.long, device=device) + act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device) + + # Concatenate prefix and action tokens + final_ids = torch.cat([prefix_ids, act_ids], dim=1) + final_mask = torch.cat([prefix_mask, act_mask], dim=1) + + batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()} + + # Pad to max length + padded_output = self.paligemma_tokenizer.pad( + batch_inputs, padding="longest", max_length=self.max_input_seq_len, return_tensors="pt" + ) + padded_mask = padded_output["attention_mask"] + + # Create 
attention mask (excludes prefix) + att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens + + # Create token type IDs + token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens) + + # Return all masks + return { + "input_ids": padded_output["input_ids"], + "attention_mask": att_mask, + "padded_mask": padded_mask, + "loss_mask": att_mask & padded_mask, # loss is computed not on prefix, and not on padding + "token_type_ids": token_type_ids, + } + + def get_config(self) -> dict[str, Any]: + """Returns the serializable configuration of the processor.""" + return { + "max_action_dim": self.max_action_dim, + "fast_skip_tokens": self.fast_skip_tokens, + "max_input_seq_len": self.max_input_seq_len, + "padding_side": self.padding_side, + "task_key": self.task_key, + "state_key": self.state_key, + } + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Adds feature definitions for the tokenized PI0FAST inputs. + + Args: + features: The dictionary of existing policy features. + + Returns: + The updated dictionary of policy features. + """ + # Add features for tokenized inputs + if "pi0fast_input_ids" not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION]["pi0fast_input_ids"] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,) + ) + + if "pi0fast_attention_mask" not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION]["pi0fast_attention_mask"] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,) + ) + + if "pi0fast_padded_mask" not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION]["pi0fast_padded_mask"] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,) + ) + + if "pi0fast_loss_mask" not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION]["pi0fast_loss_mask"] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,) + ) + + if "pi0fast_token_type_ids" not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION]["pi0fast_token_type_ids"] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,) + ) + + return features