mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
d656da8ccc
Move PI0 and PI0.5 noise/time sampling into the policy wrappers so the compiled PyTorch cores receive them as tensor inputs. This keeps Beta sampling out of torch.compile on MPS, avoiding aten::_sample_dirichlet compilation errors while preserving the CUDA training path. Validation: .venv/bin/python -m pre_commit run --files src/lerobot/policies/pi0/modeling_pi0.py src/lerobot/policies/pi05/modeling_pi05.py; .venv/bin/python -m pytest -sv -rs tests/policies/pi0_pi05/test_pi0.py tests/policies/pi0_pi05/test_pi05.py tests/policies/pi0_pi05/test_pi0_rtc.py tests/policies/pi0_pi05/test_pi05_rtc.py Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
1294 lines
52 KiB
Python
1294 lines
52 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import builtins
|
|
import copy
|
|
import logging
|
|
import math
|
|
from collections import deque
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
|
|
|
import torch
|
|
import torch.nn.functional as F # noqa: N812
|
|
from torch import Tensor, nn
|
|
|
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
|
|
|
# Conditional import for type checking and lazy loading
|
|
if TYPE_CHECKING or _transformers_available:
|
|
from transformers.models.auto import CONFIG_MAPPING
|
|
from transformers.models.gemma import modeling_gemma
|
|
|
|
from ..pi_gemma import (
|
|
PaliGemmaForConditionalGenerationWithPiGemma,
|
|
PiGemmaForCausalLM,
|
|
_gated_residual,
|
|
layernorm_forward,
|
|
)
|
|
else:
|
|
CONFIG_MAPPING = None
|
|
modeling_gemma = None
|
|
PiGemmaForCausalLM = None
|
|
_gated_residual = None
|
|
layernorm_forward = None
|
|
PaliGemmaForConditionalGenerationWithPiGemma = None
|
|
from lerobot.configs import PreTrainedConfig
|
|
from lerobot.utils.constants import (
|
|
ACTION,
|
|
OBS_LANGUAGE_ATTENTION_MASK,
|
|
OBS_LANGUAGE_TOKENS,
|
|
OPENPI_ATTENTION_MASK_VALUE,
|
|
)
|
|
|
|
from ..pretrained import PreTrainedPolicy, T
|
|
from ..rtc.modeling_rtc import RTCProcessor
|
|
from .configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
|
|
|
|
|
class ActionSelectKwargs(TypedDict, total=False):
|
|
inference_delay: int | None
|
|
prev_chunk_left_over: Tensor | None
|
|
execution_horizon: int | None
|
|
|
|
|
|
def get_safe_dtype(target_dtype, device_type):
    """Return a dtype usable on ``device_type`` in place of ``target_dtype``.

    Substitutions:
        - ``"mps"`` + ``torch.float64``  -> ``torch.float32``
        - ``"cpu"`` + ``torch.bfloat16`` -> ``torch.float32``

    Every other combination is returned unchanged.
    """
    downcast_on_mps = device_type == "mps" and target_dtype == torch.float64
    upcast_on_cpu = device_type == "cpu" and target_dtype == torch.bfloat16
    return torch.float32 if (downcast_on_mps or upcast_on_cpu) else target_dtype
|
|
|
|
|
|
def create_sinusoidal_pos_embedding(  # see openpi `create_sinusoidal_pos_embedding` (device handling adapted)
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Compute sine-cosine positional embedding vectors for scalar positions.

    Args:
        time: 1-D tensor of shape ``(batch_size,)`` holding the scalar positions.
        dimension: Embedding size; must be even (first half sine, second half cosine).
        min_period: Smallest sinusoid period.
        max_period: Largest sinusoid period.
        device: ``torch.device`` or device string (e.g. ``"cpu"``) on which the
            frequency table is built.

    Returns:
        Tensor of shape ``(batch_size, dimension)`` equal to
        ``[sin(2*pi*t/p), cos(2*pi*t/p)]`` over ``dimension // 2`` geometrically spaced
        periods ``p`` from ``min_period`` to ``max_period``.

    Raises:
        ValueError: If ``dimension`` is odd or ``time`` is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # Accept device strings as well as torch.device objects: the default "cpu" is a str,
    # which has no `.type` attribute, so normalize before querying the device type.
    device = torch.device(device)

    # Prefer float64 for the frequency table; get_safe_dtype falls back to float32 on MPS.
    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Outer product of angular frequencies and positions.
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
|
|
|
|
|
def sample_beta(alpha, beta, bsize, device):  # see openpi `sample_beta` (CPU sampling added)
    """Draw ``bsize`` samples from ``Beta(alpha, beta)`` and move them to ``device``.

    The concentration parameters are built as float32 CPU tensors and the draw happens
    on the CPU, because Beta sampling relies on ``_sample_dirichlet``, which is not
    implemented for MPS.

    Returns:
        float32 tensor on ``device``; shape ``(bsize,)`` for scalar ``alpha``/``beta``.
    """
    concentrations = (
        torch.tensor(alpha, dtype=torch.float32),
        torch.tensor(beta, dtype=torch.float32),
    )
    cpu_samples = torch.distributions.Beta(*concentrations).sample((bsize,))
    return cpu_samples.to(device)
|
|
|
|
|
|
def make_att_2d_masks(pad_masks, att_masks):  # see openpi `make_att_2d_masks` (exact copy)
    """Build a boolean ``[B, N, N]`` attention mask (adapted from big_vision).

    Token ``i`` may attend to token ``j`` when both are valid (non-padding) and the
    running sum of ``att_masks`` at ``j`` is no larger than at ``i``. A ``1`` in
    ``att_masks`` therefore starts a new block that earlier tokens cannot see, and a
    ``0`` joins the current block. Examples for a single row:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens attend among
        themselves and the last 3 are causal. The first entry could also be a 1
        without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a block
        attend to all previous blocks and to every token of their own block.

    Args:
        pad_masks: bool[B, N], True where the token is real input, False for padding.
        att_masks: int/bool[B, N], 1 where a new attention block starts, 0 where the
            token shares the mask of the previous token.

    Returns:
        bool[B, N, N] mask where ``[b, i, j]`` is True if query ``i`` may attend to key ``j``.

    Raises:
        ValueError: If either mask is not 2-D (the message is the offending ``ndim``).
    """
    for mask in (att_masks, pad_masks):
        if mask.ndim != 2:
            raise ValueError(mask.ndim)

    block_ids = att_masks.cumsum(dim=1)
    # [b, i, j] -> key j's block id is <= query i's block id.
    block_visible = block_ids.unsqueeze(1) <= block_ids.unsqueeze(2)
    # [b, i, j] -> both query i and key j are real tokens.
    both_valid = pad_masks.unsqueeze(1) * pad_masks.unsqueeze(2)
    return block_visible & both_valid
|
|
|
|
|
|
def pad_vector(vector, new_dim):
    """Right-pad the last dimension of ``vector`` with zeros up to ``new_dim``.

    Works for ``(batch_size, features)`` and ``(batch_size, seq_len, features)`` inputs
    alike. If the last dimension is already ``>= new_dim``, ``vector`` itself is returned
    (no copy, no truncation).
    """
    missing = new_dim - vector.shape[-1]
    if missing <= 0:
        return vector
    return F.pad(vector, (0, missing))
|
|
|
|
|
|
def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (adapted: [0, 1] float range)
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """Resize images to ``(height, width)`` without distortion, padding with black.

    The aspect ratio is preserved: images are scaled to fit inside the target size and
    then zero-padded symmetrically (an odd leftover pixel goes to the bottom/right).

    Args:
        images: Tensor of shape ``[b, h, w, c]`` / ``[b, c, h, w]``, or the unbatched
            ``[h, w, c]`` / ``[c, h, w]``. A trailing dimension ``<= 4`` is treated as the
            channel axis (channels-last). Supported dtypes: ``torch.uint8`` (0..255) and
            ``torch.float32`` (expected in ``[0, 1]``; output is clamped to that range).
        height: Target height.
        width: Target width.
        mode: Interpolation mode passed to ``F.interpolate`` ('bilinear', 'nearest', ...).

    Returns:
        Resized and padded tensor with the same layout and the same rank as ``images``.

    Raises:
        ValueError: If ``images`` has an unsupported dtype.
    """
    # Heuristic layout detection: a small trailing dimension is assumed to be channels.
    channels_last = images.shape[-1] <= 4
    unbatched = images.dim() == 3
    if unbatched:
        images = images.unsqueeze(0)  # Add batch dimension
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]

    _, _, cur_height, cur_width = images.shape

    # Scale so the whole image fits inside (height, width) while keeping its aspect ratio.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    resized_images = F.interpolate(
        images,
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
    )

    # Clamp to the valid value range (some modes, e.g. bicubic, can overshoot).
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        # Float images are in [0, 1] at this point (openpi's version uses [-1, 1]).
        resized_images = resized_images.clamp(0.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Split the leftover evenly; an odd pixel goes to the bottom / right.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # Zero is black for both uint8 and [0, 1] float images.
    constant_value = 0 if images.dtype == torch.uint8 else 0.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    # Restore the caller's layout and rank.
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
    if unbatched:
        padded_images = padded_images.squeeze(0)

    return padded_images
|
|
|
|
|
|
# Define the complete layer computation function for gradient checkpointing
|
|
def compute_layer_complete(
    layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
    """Run decoder layer ``layer_idx`` jointly over the prefix and action-expert streams.

    Each stream is normalized and projected to queries/keys/values with its own layer
    weights. The projections are concatenated along the sequence axis so that a single
    attention call (using PaliGemma's rotary embedding, scaling and attention module)
    covers both streams. The attention output is then split back per stream and passed
    through that stream's output projection, gated residuals and MLP.

    Args:
        layer_idx: Index of the decoder layer in both models.
        inputs_embeds: Per-stream hidden states ``[prefix, suffix]``; assumed to be
            ``(batch, seq_i, width_i)`` — TODO confirm.
        attention_mask: Mask passed to ``eager_attention_forward`` over the concatenated sequence.
        position_ids: Positions for the rotary embedding of the concatenated sequence.
        adarms_cond: Per-stream adaRMS conditioning (entries may be ``None``).
        paligemma: PaliGemma model providing stream 0 and the shared attention pieces.
        gemma_expert: Action-expert model providing stream 1.

    Returns:
        List with the updated hidden states of each stream, in input order.
    """
    models = [paligemma.model.language_model, gemma_expert.model]
    query_states = []
    key_states = []
    value_states = []
    gates = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states.append(query_state)
        key_states.append(key_state)
        value_states.append(value_state)

    # Joint attention: concatenate both streams along the sequence axis.
    query_states = torch.cat(query_states, dim=2)
    key_states = torch.cat(key_states, dim=2)
    value_states = torch.cat(value_states, dim=2)

    # All-zeros placeholder input for the rotary embedding (batch, seq, head_dim).
    dummy_tensor = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )

    batch_size = query_states.shape[0]
    # PaliGemma's attention module and scaling drive the joint attention of both streams.
    shared_attn = paligemma.model.language_model.layers[layer_idx].self_attn
    att_output, _ = modeling_gemma.eager_attention_forward(
        shared_attn,
        query_states,
        key_states,
        value_states,
        attention_mask,
        shared_attn.scaling,
    )
    # Merge heads back into the hidden axis. The head count is read from the query tensor
    # (batch, num_heads, seq, head_dim) instead of being hard-coded to 8, so Gemma
    # variants with a different number of attention heads reshape correctly.
    num_heads = query_states.shape[1]
    head_dim = shared_attn.head_dim
    att_output = att_output.reshape(batch_size, -1, num_heads * head_dim)

    # Split the attention output back per stream and finish each stream's layer.
    outputs_embeds = []
    start_pos = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        end_pos = start_pos + hidden_states.shape[1]
        if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
            att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
        # first residual
        out_emb = _gated_residual(hidden_states, out_emb, gates[i])
        after_first_residual = out_emb.clone()
        out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
        # Convert to bfloat16 if the next layer (mlp) uses bfloat16
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        # second residual
        out_emb = _gated_residual(after_first_residual, out_emb, gate)
        outputs_embeds.append(out_emb)
        start_pos = end_pos
    return outputs_embeds
|
|
|
|
|
|
class GemmaConfig:  # see openpi `gemma.py: Config`
    """Architecture hyper-parameters of a Gemma model variant.

    Attributes:
        width: Hidden size.
        depth: Number of decoder layers.
        mlp_dim: MLP intermediate size.
        num_heads: Number of attention heads.
        num_kv_heads: Number of key/value heads.
        head_dim: Per-head dimension.
    """

    def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
        # Tuple assignment runs left to right, so attributes are set in signature order.
        self.width, self.depth, self.mlp_dim = width, depth, mlp_dim
        self.num_heads, self.num_kv_heads, self.head_dim = num_heads, num_kv_heads, head_dim
|
|
|
|
|
|
def get_gemma_config(variant: str) -> GemmaConfig:  # see openpi `gemma.py: get_config`
    """Return a fresh ``GemmaConfig`` for the named Gemma variant.

    Supported variants share depth 18, 8 attention heads, 1 key/value head and head
    dimension 256; they differ in width and MLP size:

        - ``"gemma_300m"``: width 1024, mlp_dim 4096
        - ``"gemma_2b"``:   width 2048, mlp_dim 16384

    Raises:
        ValueError: If ``variant`` is not one of the names above.
    """
    if variant == "gemma_300m":
        width, mlp_dim = 1024, 4096
    elif variant == "gemma_2b":
        width, mlp_dim = 2048, 16_384
    else:
        raise ValueError(f"Unknown variant: {variant}")

    return GemmaConfig(
        width=width,
        depth=18,
        mlp_dim=mlp_dim,
        num_heads=8,
        num_kv_heads=1,
        head_dim=256,
    )
|
|
|
|
|
|
class PaliGemmaWithExpertModel(nn.Module):  # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` (almost an exact copy)
    """PaliGemma vision-language model paired with a Gemma action expert for PI05.

    ``self.paligemma`` encodes the image/language prefix. ``self.gemma_expert`` (a Gemma
    decoder whose token-embedding table is removed) processes the action suffix.
    ``forward`` runs either model alone, or both jointly. In joint mode their
    queries/keys/values are concatenated in every layer, so the two streams share one
    attention computation (subject to the attention mask).
    """

    def __init__(
        self,
        vlm_config,
        action_expert_config,
        use_adarms=None,
        precision: Literal["bfloat16", "float32"] = "bfloat16",
        image_size: int = DEFAULT_IMAGE_SIZE,
        freeze_vision_encoder: bool = False,
        train_expert_only: bool = False,
    ):
        """Build both transformers from ``GemmaConfig``-style hyper-parameters.

        Args:
            vlm_config: Hyper-parameters (width, depth, mlp_dim, heads, kv heads, head_dim)
                of the PaliGemma language model.
            action_expert_config: Hyper-parameters of the Gemma action expert.
            use_adarms: Two flags ``[vlm, expert]`` enabling adaRMS conditioning per model.
                Defaults to ``[False, False]``.
            precision: ``"bfloat16"`` or ``"float32"``; see ``to_bfloat16_for_selected_params``.
            image_size: Image resolution configured for the vision tower.
            freeze_vision_encoder: Freeze the vision tower and keep it in eval mode.
            train_expert_only: Freeze all of PaliGemma and keep it in eval mode.
        """
        if use_adarms is None:
            use_adarms = [False, False]
        super().__init__()
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only

        vlm_config_hf = CONFIG_MAPPING["paligemma"]()
        vlm_config_hf._vocab_size = 257152  # noqa: SLF001
        vlm_config_hf.image_token_index = 257152
        vlm_config_hf.text_config.hidden_size = vlm_config.width
        vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
        vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
        vlm_config_hf.text_config.head_dim = vlm_config.head_dim
        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
        vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
        vlm_config_hf.text_config.dtype = "float32"
        vlm_config_hf.text_config.vocab_size = 257152
        vlm_config_hf.text_config.use_adarms = use_adarms[0]
        vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
        vlm_config_hf.vision_config.image_size = image_size
        vlm_config_hf.vision_config.intermediate_size = 4304
        vlm_config_hf.vision_config.projection_dim = 2048
        vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
        vlm_config_hf.vision_config.dtype = "float32"

        action_expert_config_hf = CONFIG_MAPPING["gemma"](
            head_dim=action_expert_config.head_dim,
            hidden_size=action_expert_config.width,
            intermediate_size=action_expert_config.mlp_dim,
            num_attention_heads=action_expert_config.num_heads,
            num_hidden_layers=action_expert_config.depth,
            num_key_value_heads=action_expert_config.num_kv_heads,
            vocab_size=257152,
            hidden_activation="gelu_pytorch_tanh",
            dtype="float32",
            use_adarms=use_adarms[1],
            adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
        )

        self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
        self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
        # The expert is only ever fed `inputs_embeds`, so its token-embedding table is dropped.
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_for_selected_params(precision)
        self._set_requires_grad()

    def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
        """Cast the parameters to ``precision``.

        For ``"bfloat16"``, everything is cast to bfloat16 except the vision tower, the
        multi-modal projector and the normalization layers, which are kept in float32.

        Raises:
            ValueError: If ``precision`` is neither ``"bfloat16"`` nor ``"float32"``.
        """
        if precision == "bfloat16":
            self.to(dtype=torch.bfloat16)
        elif precision == "float32":
            self.to(dtype=torch.float32)
            return
        else:
            raise ValueError(f"Invalid precision: {precision}")

        # Keep the full vision path in float32 so its dtype never toggles (toggling causes the
        # optimizer's "same dtype" error). Saves memory vs full float32; uses more memory than
        # keeping only 3 params in float32.
        params_to_keep_float32 = [
            "vision_tower",
            "multi_modal_projector",
            "input_layernorm",
            "post_attention_layernorm",
            "model.norm",
        ]

        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_keep_float32):
                param.data = param.data.to(dtype=torch.float32)

    def _set_requires_grad(self):
        """Freeze the vision tower and/or all of PaliGemma according to the constructor flags."""
        if self.freeze_vision_encoder:
            self.paligemma.model.vision_tower.eval()
            for param in self.paligemma.model.vision_tower.parameters():
                param.requires_grad = False
        if self.train_expert_only:
            self.paligemma.eval()
            for param in self.paligemma.parameters():
                param.requires_grad = False

    def train(self, mode: bool = True):
        """Set training mode while keeping frozen sub-modules in eval mode.

        Returns:
            ``self``, as ``nn.Module.train`` does, so ``model.train()`` / ``model.eval()``
            can be chained or assigned (previously this override returned ``None``).
        """
        super().train(mode)
        if self.freeze_vision_encoder:
            self.paligemma.model.vision_tower.eval()
        if self.train_expert_only:
            self.paligemma.eval()
        return self

    def embed_image(self, image: torch.Tensor):
        """Return scaled image features from the vision tower, in the input image's dtype.

        The vision tower and multi-modal projector are kept in float32 by
        ``to_bfloat16_for_selected_params``, so the image is cast to float32 first. The
        pooled features are scaled by ``sqrt(text hidden_size)``.
        """
        out_dtype = image.dtype
        if image.dtype != torch.float32:
            image = image.to(torch.float32)
        image_outputs = self.paligemma.model.get_image_features(image)
        features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
        if features.dtype != out_dtype:
            features = features.to(out_dtype)
        return features

    def embed_language_tokens(self, tokens: torch.Tensor):
        """Look up token embeddings in the PaliGemma language model's embedding table (unscaled)."""
        return self.paligemma.model.language_model.embed_tokens(tokens)

    def forward(
        self,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: list[torch.FloatTensor] | None = None,
        use_cache: bool | None = None,
        adarms_cond: list[torch.Tensor] | None = None,
    ):
        """Run the prefix model, the expert, or both jointly.

        Args:
            attention_mask: Attention mask forwarded to the underlying models.
            position_ids: Position ids forwarded to the underlying models.
            past_key_values: KV cache forwarded when a single model runs.
            inputs_embeds: ``[prefix_embs, suffix_embs]``; set one entry to ``None`` to run
                only the other model.
            use_cache: Forwarded when a single model runs.
            adarms_cond: Per-model adaRMS conditioning ``[vlm, expert]``; defaults to ``[None, None]``.

        Returns:
            ``([prefix_output, suffix_output], prefix_past_key_values)``. Outputs of models
            that did not run are ``None``. The KV cache is only returned when the prefix
            model runs alone.
        """
        if adarms_cond is None:
            adarms_cond = [None, None]
        if inputs_embeds[1] is None:
            # Prefix only (e.g. to build the KV cache at inference time).
            prefix_output = self.paligemma.model.language_model.forward(
                inputs_embeds=inputs_embeds[0],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[0],
            )
            prefix_past_key_values = prefix_output.past_key_values
            prefix_output = prefix_output.last_hidden_state
            suffix_output = None
        elif inputs_embeds[0] is None:
            # Expert only.
            suffix_output = self.gemma_expert.model.forward(
                inputs_embeds=inputs_embeds[1],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[1],
            )
            suffix_output = suffix_output.last_hidden_state
            prefix_output = None
            prefix_past_key_values = None
        else:
            # Both streams, layer by layer with shared attention.
            models = [self.paligemma.model.language_model, self.gemma_expert.model]
            num_layers = self.paligemma.config.text_config.num_hidden_layers

            # Check if gradient checkpointing is enabled for any of the models
            use_gradient_checkpointing = (
                hasattr(self.gemma_expert.model, "gradient_checkpointing")
                and self.gemma_expert.model.gradient_checkpointing
                and self.training
            ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)

            # Process all layers with gradient checkpointing if enabled
            for layer_idx in range(num_layers):
                if use_gradient_checkpointing:
                    inputs_embeds = torch.utils.checkpoint.checkpoint(
                        compute_layer_complete,
                        layer_idx,
                        inputs_embeds,
                        attention_mask,
                        position_ids,
                        adarms_cond,
                        use_reentrant=False,
                        preserve_rng_state=False,
                        paligemma=self.paligemma,
                        gemma_expert=self.gemma_expert,
                    )
                else:
                    inputs_embeds = compute_layer_complete(
                        layer_idx,
                        inputs_embeds,
                        attention_mask,
                        position_ids,
                        adarms_cond,
                        paligemma=self.paligemma,
                        gemma_expert=self.gemma_expert,
                    )

            # final norm
            def compute_final_norms(inputs_embeds, adarms_cond):
                outputs_embeds = []
                for i, hidden_states in enumerate(inputs_embeds):
                    out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
                    outputs_embeds.append(out_emb)
                return outputs_embeds

            # Apply gradient checkpointing to final norm if enabled
            if use_gradient_checkpointing:
                outputs_embeds = torch.utils.checkpoint.checkpoint(
                    compute_final_norms,
                    inputs_embeds,
                    adarms_cond,
                    use_reentrant=False,
                    preserve_rng_state=False,
                )
            else:
                outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)

            prefix_output = outputs_embeds[0]
            suffix_output = outputs_embeds[1]
            prefix_past_key_values = None

        return [prefix_output, suffix_output], prefix_past_key_values
|
|
|
|
|
|
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
"""Core PI05 PyTorch model."""
|
|
|
|
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
|
|
super().__init__()
|
|
self.config = config
|
|
self.rtc_processor = rtc_processor
|
|
|
|
paligemma_config = get_gemma_config(config.paligemma_variant)
|
|
action_expert_config = get_gemma_config(config.action_expert_variant)
|
|
|
|
if config.image_resolution[0] != config.image_resolution[1]:
|
|
raise ValueError(
|
|
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
|
|
)
|
|
|
|
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
|
paligemma_config,
|
|
action_expert_config,
|
|
use_adarms=[False, True],
|
|
precision=config.dtype,
|
|
image_size=config.image_resolution[0],
|
|
freeze_vision_encoder=config.freeze_vision_encoder,
|
|
train_expert_only=config.train_expert_only,
|
|
)
|
|
|
|
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
|
self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim)
|
|
|
|
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
|
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
|
|
|
# Initialize gradient checkpointing flag
|
|
self.gradient_checkpointing_enabled = False
|
|
|
|
# Compile model if requested
|
|
if config.compile_model:
|
|
torch.set_float32_matmul_precision("high")
|
|
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
|
# Also compile the main forward pass used during training
|
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
|
|
|
def gradient_checkpointing_enable(self):
|
|
"""Enable gradient checkpointing for memory optimization."""
|
|
self.gradient_checkpointing_enabled = True
|
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
|
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
|
|
|
def gradient_checkpointing_disable(self):
|
|
"""Disable gradient checkpointing."""
|
|
self.gradient_checkpointing_enabled = False
|
|
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
|
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
|
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
|
|
|
def _rtc_enabled(self):
|
|
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
|
|
|
def _apply_checkpoint(self, func, *args, **kwargs):
|
|
"""Helper method to apply gradient checkpointing if enabled."""
|
|
if self.gradient_checkpointing_enabled and self.training:
|
|
return torch.utils.checkpoint.checkpoint(
|
|
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
|
)
|
|
return func(*args, **kwargs)
|
|
|
|
def _prepare_attention_masks_4d(self, att_2d_masks):
|
|
"""Helper method to prepare 4D attention masks for transformer."""
|
|
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
|
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
|
|
|
def sample_noise(self, shape, device):
|
|
return torch.normal(
|
|
mean=0.0,
|
|
std=1.0,
|
|
size=shape,
|
|
dtype=torch.float32,
|
|
device=device,
|
|
)
|
|
|
|
def sample_time(self, bsize, device):
|
|
time_beta = sample_beta(
|
|
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
|
)
|
|
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
|
return time.to(dtype=torch.float32, device=device)
|
|
|
|
def embed_prefix(
|
|
self, images, img_masks, tokens, masks
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""Embed images with SigLIP and language tokens with embedding layer."""
|
|
embs = []
|
|
pad_masks = []
|
|
att_masks = []
|
|
|
|
# Process images
|
|
for img, img_mask in zip(images, img_masks, strict=True):
|
|
|
|
def image_embed_func(img):
|
|
return self.paligemma_with_expert.embed_image(img)
|
|
|
|
img_emb = self._apply_checkpoint(image_embed_func, img)
|
|
bsize, num_img_embs = img_emb.shape[:2]
|
|
|
|
embs.append(img_emb)
|
|
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
|
att_masks += [0] * num_img_embs
|
|
|
|
# Process language tokens
|
|
def lang_embed_func(tokens):
|
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
|
lang_emb_dim = lang_emb.shape[-1]
|
|
return lang_emb * math.sqrt(lang_emb_dim)
|
|
|
|
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
|
embs.append(lang_emb)
|
|
pad_masks.append(masks)
|
|
|
|
num_lang_embs = lang_emb.shape[1]
|
|
att_masks += [0] * num_lang_embs
|
|
|
|
embs = torch.cat(embs, dim=1)
|
|
pad_masks = torch.cat(pad_masks, dim=1)
|
|
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
|
|
|
bsize = pad_masks.shape[0]
|
|
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
|
|
|
return embs, pad_masks, att_masks
|
|
|
|
def embed_suffix(self, noisy_actions, timestep):
|
|
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
|
embs = []
|
|
pad_masks = []
|
|
att_masks = []
|
|
|
|
# Embed timestep using sine-cosine positional encoding
|
|
time_emb = create_sinusoidal_pos_embedding(
|
|
timestep,
|
|
self.action_in_proj.out_features,
|
|
min_period=self.config.min_period,
|
|
max_period=self.config.max_period,
|
|
device=timestep.device,
|
|
)
|
|
time_emb = time_emb.type(dtype=timestep.dtype)
|
|
|
|
# Fuse timestep + action information using an MLP
|
|
def action_proj_func(noisy_actions):
|
|
return self.action_in_proj(noisy_actions)
|
|
|
|
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
|
|
|
def time_mlp_func(time_emb):
|
|
x = self.time_mlp_in(time_emb)
|
|
x = F.silu(x)
|
|
x = self.time_mlp_out(x)
|
|
return F.silu(x)
|
|
|
|
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
|
|
action_time_emb = action_emb
|
|
adarms_cond = time_emb
|
|
|
|
embs.append(action_time_emb)
|
|
bsize, action_time_dim = action_time_emb.shape[:2]
|
|
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
|
pad_masks.append(action_time_mask)
|
|
|
|
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
|
att_masks += [1] + ([0] * (self.config.chunk_size - 1))
|
|
|
|
embs = torch.cat(embs, dim=1)
|
|
pad_masks = torch.cat(pad_masks, dim=1)
|
|
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
|
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
|
|
|
return embs, pad_masks, att_masks, adarms_cond
|
|
|
|
def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor:
|
|
"""Do a full training forward pass and compute the loss."""
|
|
time_expanded = time[:, None, None]
|
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
|
u_t = noise - actions
|
|
|
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
|
|
|
if (
|
|
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
|
== torch.bfloat16
|
|
):
|
|
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
|
|
|
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
|
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
|
|
|
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
|
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
|
|
|
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
|
|
|
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
|
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
|
attention_mask=att_2d_masks_4d,
|
|
position_ids=position_ids,
|
|
past_key_values=None,
|
|
inputs_embeds=[prefix_embs, suffix_embs],
|
|
use_cache=False,
|
|
adarms_cond=[None, adarms_cond],
|
|
)
|
|
return suffix_out
|
|
|
|
suffix_out = self._apply_checkpoint(
|
|
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
|
)
|
|
|
|
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
|
suffix_out = suffix_out.to(dtype=torch.float32)
|
|
|
|
def action_out_proj_func(suffix_out):
|
|
return self.action_out_proj(suffix_out)
|
|
|
|
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
|
|
|
return F.mse_loss(u_t, v_t, reduction="none")
|
|
|
|
@torch.no_grad()  # see openpi `sample_actions` (slightly adapted)
def sample_actions(
    self,
    images,
    img_masks,
    tokens,
    masks,
    noise=None,
    num_steps=None,
    **kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
    """Do a full inference forward and compute the action.

    The prefix (images + language tokens) is run once through ``paligemma_with_expert``
    with ``use_cache=True``. The flow is then integrated from time 1.0 towards 0 with
    ``num_steps`` explicit Euler steps (``x_t = x_t + dt * v_t``). Each step's velocity
    comes from ``denoise_step``, or from the RTC processor wrapping it when RTC is enabled.

    Args:
        images: Preprocessed images, passed unchanged to ``embed_prefix``.
        img_masks: Per-image masks, passed unchanged to ``embed_prefix``.
        tokens: Language tokens. ``tokens.shape[0]`` is used as the batch size and
            ``tokens.device`` as the device.
        masks: Language attention mask, passed unchanged to ``embed_prefix``.
        noise: Optional starting point of the integration. When ``None`` it is drawn with
            ``self.sample_noise`` using shape
            ``(batch, config.chunk_size, config.max_action_dim)``.
        num_steps: Number of Euler steps. Defaults to ``config.num_inference_steps``.
        **kwargs: Only read when RTC is enabled. ``inference_delay``,
            ``prev_chunk_left_over`` and ``execution_horizon`` are forwarded to
            ``rtc_processor.denoise_step``.

    Returns:
        ``x_t`` after the last Euler step.
    """
    if num_steps is None:
        num_steps = self.config.num_inference_steps

    bsize = tokens.shape[0]
    device = tokens.device

    if noise is None:
        # Sample noise with padded dimension as expected by action_in_proj
        actions_shape = (
            bsize,
            self.config.chunk_size,
            self.config.max_action_dim,
        )  # Use config max_action_dim for internal processing
        noise = self.sample_noise(actions_shape, device)

    # Encode the prefix once. The returned KV cache is handed to every denoising step below.
    prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
    prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
    # Cumulative count of non-padded prefix tokens minus one, used as position ids.
    prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

    prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
    self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager"  # noqa: SLF001

    _, past_key_values = self.paligemma_with_expert.forward(
        attention_mask=prefix_att_2d_masks_4d,
        position_ids=prefix_position_ids,
        past_key_values=None,
        inputs_embeds=[prefix_embs, None],
        use_cache=True,
    )

    # Negative fixed step: time goes 1.0, 1.0 + dt, ... and reaches 0 after num_steps updates.
    dt = -1.0 / num_steps

    x_t = noise
    for step in range(num_steps):
        time = 1.0 + step * dt
        time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)

        # The timestep is bound as a default argument, so the closure keeps this
        # iteration's value instead of late-binding ``time_tensor``.
        def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
            return self.denoise_step(
                prefix_pad_masks=prefix_pad_masks,
                past_key_values=past_key_values,
                x_t=input_x_t,
                timestep=current_timestep,
            )

        if self._rtc_enabled():
            inference_delay = kwargs.get("inference_delay")
            prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
            execution_horizon = kwargs.get("execution_horizon")

            # The RTC processor produces v_t. It receives the plain step as
            # ``original_denoise_step_partial``.
            v_t = self.rtc_processor.denoise_step(
                x_t=x_t,
                prev_chunk_left_over=prev_chunk_left_over,
                inference_delay=inference_delay,
                time=time,
                original_denoise_step_partial=denoise_step_partial_call,
                execution_horizon=execution_horizon,
            )
        else:
            v_t = denoise_step_partial_call(x_t)

        # Explicit Euler update.
        x_t = x_t + dt * v_t

        # Optional debug tracking of the trajectory (processor may exist even if RTC is off).
        if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
            self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)

    return x_t
|
|
|
|
def denoise_step(
    self,
    prefix_pad_masks,
    past_key_values,
    x_t,
    timestep,
):
    """Predict the velocity for one denoising step of ``x_t`` at ``timestep``.

    The suffix (``embed_suffix`` of the noisy actions and timestep) attends to the cached
    prefix through ``past_key_values``. It is run through the action expert only
    (``inputs_embeds=[None, suffix_embs]``).

    Args:
        prefix_pad_masks: 2D padding mask of the prefix tokens ``(batch, prefix_len)``.
        past_key_values: Cache produced by the prefix-only forward pass. A deep copy is
            given to the forward call, so the caller's object is not handed over.
        x_t: Current noisy action chunk.
        timestep: Current flow time, one value per batch element.

    Returns:
        ``action_out_proj`` applied to the last ``config.chunk_size`` suffix outputs,
        cast to float32 beforehand.
    """
    suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)

    n_batch = prefix_pad_masks.shape[0]
    n_prefix = prefix_pad_masks.shape[1]
    n_suffix = suffix_pad_masks.shape[1]

    # Every suffix token may look at every non-padded prefix token. Within the suffix
    # the pattern comes from its own attention masks.
    prefix_visibility = prefix_pad_masks[:, None, :].expand(n_batch, n_suffix, n_prefix)
    suffix_visibility = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
    joint_visibility = torch.cat([prefix_visibility, suffix_visibility], dim=2)

    # Suffix positions continue right after the non-padded prefix tokens.
    n_valid_prefix = torch.sum(prefix_pad_masks, dim=-1)[:, None]
    suffix_positions = n_valid_prefix + torch.cumsum(suffix_pad_masks, dim=1) - 1

    joint_visibility_4d = self._prepare_attention_masks_4d(joint_visibility)
    self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

    cache_copy = copy.deepcopy(past_key_values)
    outputs_embeds, _ = self.paligemma_with_expert.forward(
        attention_mask=joint_visibility_4d,
        position_ids=suffix_positions,
        past_key_values=cache_copy,
        inputs_embeds=[None, suffix_embs],
        use_cache=False,
        adarms_cond=[None, adarms_cond],
    )

    expert_out = outputs_embeds[1][:, -self.config.chunk_size :].to(dtype=torch.float32)
    return self.action_out_proj(expert_out)
|
|
|
|
|
|
class PI05Policy(PreTrainedPolicy):
|
|
"""PI05 Policy for LeRobot."""
|
|
|
|
config_class = PI05Config
|
|
name = "pi05"
|
|
|
|
def __init__(
    self,
    config: PI05Config,
    **kwargs,
):
    """Build the PI0.5 policy around a new ``PI05Pytorch`` core.

    Args:
        config: Policy configuration class instance.
        **kwargs: Accepted for API compatibility. Not used by this constructor.
    """
    require_package("transformers", extra="pi")
    super().__init__(config)
    config.validate_features()
    self.config = config

    # The RTC processor has to exist first: the core model receives it at construction.
    self.init_rtc_processor()
    self.model = core = PI05Pytorch(config, rtc_processor=self.rtc_processor)

    if config.gradient_checkpointing:
        core.gradient_checkpointing_enable()
    core.to(config.device)

    self.reset()
|
|
|
|
@classmethod
def from_pretrained(
    cls: builtins.type[T],
    pretrained_name_or_path: str | Path,
    *,
    config: PreTrainedConfig | None = None,
    force_download: bool = False,
    resume_download: bool | None = None,
    proxies: dict | None = None,
    token: str | bool | None = None,
    cache_dir: str | Path | None = None,
    local_files_only: bool = False,
    revision: str | None = None,
    strict: bool = True,
    **kwargs,
) -> T:
    """Override the from_pretrained method to handle key remapping and display important disclaimer.

    Steps:
    1. Load the config unless one is given.
    2. Instantiate the policy.
    3. Read ``model.safetensors`` via ``transformers.utils.cached_file``.
    4. Pass the keys through ``_fix_pytorch_state_dict_keys`` and prefix them with
       ``"model."`` when missing.
    5. Load the result with ``load_state_dict``.

    Weight loading is best-effort. If the file cannot be resolved or read, or if any later
    step raises (including a ``strict`` mismatch), a message is printed and the model is
    returned without raising.

    Args:
        pretrained_name_or_path: Hub repo id or local path.
        config: Optional config. Loaded from ``pretrained_name_or_path`` when ``None``.
        force_download, resume_download, proxies, token, cache_dir, local_files_only,
            revision: Hub options. They are used for the config (when loaded here) and
            for the weight file.
        strict: Forwarded to ``load_state_dict``.
        **kwargs: Forwarded to ``PreTrainedConfig.from_pretrained`` and to the policy
            constructor.

    Returns:
        The policy instance.

    Raises:
        ValueError: If ``pretrained_name_or_path`` is ``None``.
    """
    print(
        "The PI05 model is a direct port of the OpenPI implementation. \n"
        "This implementation follows the original OpenPI structure for compatibility. \n"
        "Original implementation: https://github.com/Physical-Intelligence/openpi"
    )
    if pretrained_name_or_path is None:
        raise ValueError("pretrained_name_or_path is required")

    # Use provided config if available, otherwise create default config
    if config is None:
        config = PreTrainedConfig.from_pretrained(
            pretrained_name_or_path=pretrained_name_or_path,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
            **kwargs,
        )

    # Initialize model without loading weights
    model = cls(config, **kwargs)

    # Load state dict (expects keys with "model." prefix)
    try:
        print(f"Loading model from: {pretrained_name_or_path}")
        try:
            from transformers.utils import cached_file

            # The Hub options are keyword-only parameters of this method, so they are
            # never present in ``kwargs``. Pass the bound values explicitly.
            resolved_file = cached_file(
                pretrained_name_or_path,
                "model.safetensors",
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                revision=revision,
                local_files_only=local_files_only,
            )
            from safetensors.torch import load_file

            original_state_dict = load_file(resolved_file)
            print("✓ Loaded state dict from model.safetensors")
        except Exception as e:
            print(f"Could not load state dict from remote files: {e}")
            print("Returning model without loading pretrained weights")
            return model

        # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
        fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

        # Then add "model." prefix for all keys that don't already have it
        remapped_state_dict = {}
        remap_count = 0
        for key, value in fixed_state_dict.items():
            if key.startswith("model."):
                remapped_state_dict[key] = value
            else:
                remapped_state_dict[f"model.{key}"] = value
                remap_count += 1

        if remap_count > 0:
            print(f"Remapped {remap_count} state dict keys")

        # Load the remapped state dict into the model
        missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)

        # Report at most 5 keys per category, followed by a count of the rest.
        max_shown = 5
        for label, keys in (("Missing", missing_keys), ("Unexpected", unexpected_keys)):
            if not keys:
                continue
            print(f"{label} keys when loading state dict: {len(keys)} keys")
            for key in keys[:max_shown]:
                print(f" - {key}")
            if len(keys) > max_shown:
                print(f" ... and {len(keys) - max_shown} more")

        if not missing_keys and not unexpected_keys:
            print("All keys loaded successfully!")

    except Exception as e:
        print(f"Warning: Could not load state dict: {e}")

    return model
|
|
|
|
def _fix_pytorch_state_dict_keys(
    self, state_dict, model_config
):  # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
    """Fix state dict keys to match current model architecture.

    Rules, applied per key:
    - Gemma-expert ``input_layernorm``/``post_attention_layernorm`` ``.weight`` keys and the
      expert's final ``model.norm.weight`` are dropped (with a warning) when the expert
      config has ``use_adarms`` set. Those norms are structured differently under adaRMS.
    - ``action_time_mlp_in.``/``action_time_mlp_out.`` prefixes are renamed to
      ``time_mlp_in.``/``time_mlp_out.``.
    - ``state_proj.`` keys are dropped, since PI0.5 has no state projection.
    - ``patch_embedding`` keys are kept but trigger a warning.
    - ``paligemma_with_expert.paligemma.lm_head.weight`` (with or without a ``model.``
      prefix) is kept and also cloned into
      ``model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight``.

    Args:
        state_dict: Checkpoint mapping from parameter name to tensor.
        model_config: Not used by this implementation.

    Returns:
        A new dict with the adjusted keys. ``state_dict`` is only read.
    """
    import re

    expert_layer_norm_re = re.compile(
        r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight"
    )
    expert_final_norm_re = re.compile(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight")
    lm_head_keys = (
        "model.paligemma_with_expert.paligemma.lm_head.weight",
        "paligemma_with_expert.paligemma.lm_head.weight",
    )
    tied_embedding_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"

    def expert_uses_adarms():
        # Only consulted for keys that match one of the expert norm patterns.
        return getattr(self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False)

    fixed_state_dict = {}
    for key, value in state_dict.items():
        if expert_layer_norm_re.match(key) and expert_uses_adarms():
            logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
            continue
        if expert_final_norm_re.match(key) and expert_uses_adarms():
            logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
            continue
        if key.startswith("state_proj."):
            logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
            continue

        # Checkpoints may still use the action_time_mlp_* naming for the time MLP.
        if key.startswith("action_time_mlp_in."):
            new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
        elif key.startswith("action_time_mlp_out."):
            new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")
        else:
            new_key = key

        if "patch_embedding" in key:
            logging.warning(f"Vision embedding key might need handling: {key}")

        if key in lm_head_keys:
            fixed_state_dict[tied_embedding_key] = value.clone()

        fixed_state_dict[new_key] = value

    return fixed_state_dict
|
|
|
|
def get_optim_params(self) -> dict:
    """Hand every parameter of the policy to the optimizer.

    No parameter groups are built; the result is exactly ``self.parameters()``.
    """
    every_param = self.parameters()
    return every_param
|
|
|
|
def reset(self):
    """Start a fresh episode by discarding every queued action.

    Called when the environment resets. Both queues are bounded by
    ``config.n_action_steps``.
    """
    horizon = self.config.n_action_steps
    self._action_queue = deque(maxlen=horizon)
    self._queues = {ACTION: deque(maxlen=horizon)}
|
|
|
|
def init_rtc_processor(self):
    """(Re)create the RTC processor from ``config.rtc_config`` and share it with ``self.model``.

    A processor is built whenever ``rtc_config`` is set, even if RTC itself is disabled, so
    denoising data can still be tracked. Without a config, ``self.rtc_processor`` is
    ``None``. If ``self.model`` already exists, its ``rtc_processor`` attribute is updated
    to the same object.
    """
    self.rtc_processor = None
    rtc_config = self.config.rtc_config
    if rtc_config is not None:
        self.rtc_processor = RTCProcessor(rtc_config)

    core = getattr(self, "model", None)
    if core is not None:
        core.rtc_processor = self.rtc_processor
|
|
|
|
def _rtc_enabled(self) -> bool:
    """Report whether Real-Time Chunking is both configured and switched on."""
    rtc_config = self.config.rtc_config
    if rtc_config is None:
        return False
    return rtc_config.enabled
|
|
|
|
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    """Preprocess images for the model.

    Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
    PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
    An image whose ``shape[1]`` is not 3 is treated as channels-last [B, H, W, C] and
    returned in that layout.

    Each image key of ``config.image_features`` found in ``batch`` is processed as follows:
    moved to the model's device, cast to float32, resized with padding to
    ``config.image_resolution`` when its spatial size differs, and rescaled to [-1, 1].
    For every expected key missing from ``batch``, a placeholder filled with -1 and an
    all-False mask are appended after the real images.

    Returns:
        ``(images, img_masks)``. Real images come with all-True boolean masks of length B.

    Raises:
        ValueError: If none of ``config.image_features`` is present in ``batch``.
    """
    images = []
    img_masks = []

    # Get device from model parameters
    device = next(self.parameters()).device

    present_img_keys = [key for key in self.config.image_features if key in batch]
    missing_img_keys = [key for key in self.config.image_features if key not in batch]

    if len(present_img_keys) == 0:
        raise ValueError(
            "All image features are missing from the batch. At least one expected. "
            f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
        )

    # Compare resolutions as tuples. ``torch.Size`` (a tuple subclass) never equals a list,
    # so a list-valued ``image_resolution`` (e.g. if the config came from JSON) would
    # otherwise force a resize on every call, even for images already at the target size.
    target_resolution = tuple(self.config.image_resolution)

    # Preprocess image features present in the batch
    for key in present_img_keys:
        img = batch[key]

        # Ensure tensor is on the same device as the model
        if img.device != device:
            img = img.to(device)

        # Ensure float32 dtype for consistency
        if img.dtype != torch.float32:
            img = img.to(torch.float32)

        # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
        is_channels_first = img.shape[1] == 3  # Check if channels are in dimension 1

        if is_channels_first:
            # Convert [B, C, H, W] to [B, H, W, C] for processing
            img = img.permute(0, 2, 3, 1)

        # from openpi preprocess_observation_pytorch: Resize with padding if needed
        if tuple(img.shape[1:3]) != target_resolution:
            img = resize_with_pad_torch(img, *target_resolution)

        # Normalize from [0,1] to [-1,1] as expected by siglip
        img = img * 2.0 - 1.0

        # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
        if is_channels_first:
            img = img.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

        images.append(img)
        # Create mask (all ones for real images)
        bsize = img.shape[0]
        mask = torch.ones(bsize, dtype=torch.bool, device=device)
        img_masks.append(mask)

    # Fill missing cameras with -1 placeholders (shaped like the last real image) and a
    # zero mask.
    # NOTE(review): placeholders are appended after all present images, so a missing
    # camera shifts the slot of every camera listed after it. Confirm this matches the
    # layout seen during training.
    for _ in missing_img_keys:
        img = torch.ones_like(img) * -1  # Padded with -1 for SigLIP
        mask = torch.zeros_like(mask)  # Mask is zero for empty cameras
        images.append(img)
        img_masks.append(mask)

    return images, img_masks
|
|
|
|
def prepare_action(self, batch):
    """Return ``batch[ACTION]`` padded to ``config.max_action_dim`` with ``pad_vector``."""
    return pad_vector(batch[ACTION], self.config.max_action_dim)
|
|
|
|
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
    """Select a single action given environment observations.

    Actions are served from an internal queue. When the queue is empty, a chunk is
    predicted with ``predict_action_chunk`` and its first ``config.n_action_steps`` steps
    are queued.

    Args:
        batch: Observation batch forwarded to ``predict_action_chunk``.

    Returns:
        The next queued action (one entry per batch element).

    Raises:
        AssertionError: If RTC is enabled. Use ``predict_action_chunk`` instead.
    """
    # Explicit raise rather than ``assert`` so the check also runs under ``python -O``.
    # The exception type is unchanged for backward compatibility.
    if self._rtc_enabled():
        raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk")

    self.eval()

    # Action queue logic for n_action_steps > 1
    if not self._action_queue:
        actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
        # Transpose to get shape (n_action_steps, batch_size, action_dim)
        self._action_queue.extend(actions.transpose(0, 1))

    return self._action_queue.popleft()
|
|
|
|
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
    """Predict a full chunk of actions for the given observations.

    Extra keyword arguments (the RTC options such as ``inference_delay``,
    ``prev_chunk_left_over`` and ``execution_horizon``) go straight to
    ``self.model.sample_actions``. PI0.5 takes no separate state input. The result is
    trimmed from the padded action width to ``output_features[ACTION].shape[0]``.
    """
    self.eval()

    images, img_masks = self._preprocess_images(batch)
    tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
    masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]

    padded_chunk = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)

    # Drop the padding dimensions beyond the real action size.
    action_dim = self.config.output_features[ACTION].shape[0]
    return padded_chunk[:, :, :action_dim]
|
|
|
|
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
    """Run the batch through the model and compute the loss for training.

    Args:
        batch: Training batch containing observations and actions.
        reduction: How to reduce the loss. Options:
            - "mean": Return scalar mean loss (default, backward compatible)
            - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
            Any other value is treated like "mean".

    Returns:
        ``(loss, loss_dict)``. ``loss_dict`` holds ``"loss_per_dim"`` (mean over batch and
        time for each real action dimension) and ``"loss"`` (the scalar mean as a float).
    """
    images, img_masks = self._preprocess_images(batch)
    tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
    masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
    actions = self.prepare_action(batch)

    # Noise and flow time are sampled here in the wrapper and handed to the model core as
    # plain tensor inputs.
    noise = self.model.sample_noise(actions.shape, actions.device)
    time = self.model.sample_time(actions.shape[0], actions.device)

    # No separate state input is needed for PI05.
    per_element = self.model.forward(images, img_masks, tokens, masks, actions, noise, time)

    # Keep only the real (unpadded) action dimensions.
    action_dim = self.config.output_features[ACTION].shape[0]
    per_element = per_element[:, :, :action_dim]

    loss_dict = {"loss_per_dim": per_element.mean(dim=[0, 1]).detach().cpu().numpy().tolist()}

    if reduction != "none":
        scalar_loss = per_element.mean()
        loss_dict["loss"] = scalar_loss.item()
        return scalar_loss, loss_dict

    # Per-sample losses (B,): average over time and action dimensions.
    per_sample_loss = per_element.mean(dim=(1, 2))
    loss_dict["loss"] = per_sample_loss.mean().item()
    return per_sample_loss, loss_dict
|
|
|
|
def _get_default_peft_targets(self) -> dict[str, any]:
|
|
"""Return default PEFT target modules for PI0.5 fine-tuning."""
|
|
common_projections = (
|
|
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
|
)
|
|
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
|
return {
|
|
"target_modules": target_modules,
|
|
"modules_to_save": [],
|
|
}
|