mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
abde7be3b3
* initial commit * change device in test * do detailed import * adhere to python 3.11 syntax * fix autodocstring * additionally * do same in other files * add model. prefix to all keys in state dict * use dummy stats * add pi05 * also shorten action_steps * fix test * all test pass! and fix tokenizer max length between 05 and 0 * remove test * fix transformer dependency * fix test * split pi0 and pi05 policy in seperate files * fix test * fix push to hub test * add some comments, license and readme * remove warning in config * add pi05 to factory * remove check * rename action_horizon to chunk_size * clean up padding of state and action (more in line with lerobot pi0) * add openpi image transforms for training and add more flexibility to _preprocess_images similar to lerobot pi0 * fix key match from pytorch state dict (similar keys to openpi implementation now) * also for pi05 * update to python 3.11 * revert to openpi transformer replace python 3.11 * fix(modeling pi0): nit warning message * use safeauto_docstring * fix: remove unused param * fix from pretrained * add preprocess tests * also compile forward method * Do not add model prefix to normalization * use same name for action and state dim as lerobot pi0 and remove fixed image keys * load from pretrained_path * temp: hardcode base model * fix override self.pretrained_path = None overwrite * rename to loss * remove additional image augmentations, lerobot dataset already does this * Add docs * put tests in test folder * Add test to instatiate all base models * go back to python 3.10 * update docs * adapt docs pi05 * change docs: finetune base model options * minor docs fixes and dependencies * remove todo * cast float64 to float32 for mps * skip if no transformers * fix tests * add new models to modelcard * add back init * fix circular input * feat: only run pi test on GPU * remove require_nightly_gpu * replace decorator test_pi0_openpi * rename action_dim, state_dim to max_action_dim, max_state_dim * fix doc 
and constants * cleanup tests * fix from pretrained * fix tests * add comment pi0 pi05 tests, add image features to pi0 pi05 hub tests * fix, state is included in language not in flow head * Move test to specific folder * and paligemma task with newline * remove add_special_tokens, not needed * feedback pr * Remove previous pi0 and rename pi0_openpi and pi05_openpi * Add Quantile stats to LeRobotDataset (#1985) * - Add RunningQuantileStats class for efficient histogram-based quantile computation - Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset - Support quantile computation during episode collection and aggregation - Add comprehensive function-based test suite (24 tests) for quantile functionality - Maintain full backward compatibility with existing stats computation - Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization * style fixes, make quantiles computation by default to new datasets * fix tests * - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each features instead of being chosen by the user - Fortified tests. * - add helper functions to reshape stats - add missing test for quantiles * - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles. - Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles. * style fixes * Added missing lisence * Simplify compute_stats * - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that dont have quatniles - modified quantile computation instead of using the edge for the value, interpolate the values in the bin * rename pi0/pi05 files * Remove open pi patch and use custom transformer branch for now * renaming * fix * Revert "fix" This reverts commit1ea65730ac. 
* fix naming * feet(pi0/pi0.5): add pipeline (#2009) * feat(processor): convert openpi model with processor * TODO: Make test works * fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests - Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`. - Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`. - Enhanced task handling in tests to ensure proper formatting and batch size consistency. - Cleaned up commented-out test code for clarity. * refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy - Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions. * refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration - Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions. - Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`. - Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter. - Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability. - Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility. 
* feat(processor): convert openpi model with processor * TODO: Make test works * fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests - Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`. - Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`. - Enhanced task handling in tests to ensure proper formatting and batch size consistency. - Cleaned up commented-out test code for clarity. * refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy - Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions. * refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration - Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions. - Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`. - Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter. - Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability. - Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility. * refactor(pi05): update imports and rename configuration classes - Changed imports to reflect the new naming convention for PI05 configuration and policy classes. - Renamed `PI05OpenPIConfig` to `PI05Config` and `PI05OpenPIPolicy` to `PI05Policy` for consistency. 
- Introduced a new processor file for PI05, implementing pre-processing and post-processing steps. - Updated tests to utilize the renamed classes, ensuring functionality and consistency across the codebase. * update(pi05): increase tokenizer_max_length for improved processing - Changed the `tokenizer_max_length` from 48 to 200 to enhance the model's capability in handling longer sequences. - This adjustment aims to improve the overall performance and flexibility of the PI05 configuration. * add default for state (max_state_dim) * correct naming * fix import * cleanup code * remove unused test * us quantiles for action * move to device * remove discrete state assert * fix pi05 test * move pi05 to device * use base models in comparison tests * small renames for tests * change number of tokens pi05 test * fix openpi tokenization in test * fix hub test * fix test * assert lerobot vs openpi tests --------- Co-authored-by: Pepijn <pepijn@huggingface.co> * add headers * add back previously removed imports * update if statement load processor with dataset stats * remove to avoid circular import * inject dataset stats for pretrained models * check normalization before applying * add link to quantile augument script * fix(policies): transformers import for ci in PI0 & PI05 (#2039) * fix(policies): transformers import for ci in PI0 * fix(policies): transformers import for ci in PI05 * test(processor): fix expected raise when normalization types are missing (#2040) * switch normalization order pipeline for pi05 * Fix/quantiles script (#2064) * refactor augment stats with quantiles script add parallelization for faster processing shift the quantile normalization between -1 1 * fix replay buffer tests * fix comment * overwrite the pipeline normalization features with the policy features * remove double normalization overwrite * cleanup from pretrained * remove typo * also set norm_map * fix(augment_quantiles) images incorrectly divided by 255 * clamp quantiles * link to lerobot 
base models * rename tests * encorperate PR feedback * update docstring for RunningQuantileStats * update doc links * Revert "clamp quantiles" This reverts commit172207471c. * fix self.paligemma * fix tests related to quantiles that were scaled to [0,1], the new range is [-1, 1] * fix libero doc and use different transformer branch * use fix branch instead of feat * update results libero * add new line * fix formatting * precommit * update results libero * update libero doc * update title * final changes * add quantiles to test * run pre commit --------- Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
1164 lines
46 KiB
Python
#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import builtins
import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal

import torch
import torch.nn.functional as F  # noqa: N812
from torch import Tensor, nn

from lerobot.utils.import_utils import _transformers_available

# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
    from transformers.models.auto import CONFIG_MAPPING
    from transformers.models.gemma import modeling_gemma
    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
    CONFIG_MAPPING = None
    modeling_gemma = None
    GemmaForCausalLM = None
    PaliGemmaForConditionalGeneration = None

from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.utils.constants import (
    ACTION,
    OBS_LANGUAGE_ATTENTION_MASK,
    OBS_LANGUAGE_TOKENS,
    OPENPI_ATTENTION_MASK_VALUE,
)


def get_safe_dtype(target_dtype, device_type):
    """Return `target_dtype`, or a substitute that `device_type` can handle.

    float64 is replaced by float32 on MPS, and bfloat16 is replaced by float32 on CPU.
    """
    if device_type == "mps" and target_dtype == torch.float64:
        return torch.float32
    if device_type == "cpu":
        # CPU doesn't support bfloat16, use float32 instead
        if target_dtype == torch.bfloat16:
            return torch.float32
        if target_dtype == torch.float64:
            return torch.float64
    return target_dtype


def create_sinusoidal_pos_embedding(  # see openpi `create_sinusoidal_pos_embedding` (copy; `device` may also be a string)
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # Accept both strings (such as the "cpu" default) and torch.device objects;
    # `device.type` below would fail on a plain string.
    device = torch.device(device)
    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)


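# Illustrative sketch, not used by the policy: a pure-Python version of the sinusoidal
# embedding above for one scalar time `t`, showing that the periods are spaced geometrically
# between `min_period` and `max_period` (the same spacing torch.linspace produces for the
# exponent). The helper name `_example_sinusoidal_embedding` is hypothetical.
def _example_sinusoidal_embedding(
    t: float, dimension: int, min_period: float, max_period: float
) -> list[float]:
    """Return the sin terms for every period followed by the cos terms for every period."""
    import math

    half = dimension // 2
    # Evenly spaced exponents in [0, 1], like torch.linspace(0.0, 1.0, half).
    fractions = [i / (half - 1) if half > 1 else 0.0 for i in range(half)]
    periods = [min_period * (max_period / min_period) ** f for f in fractions]
    angles = [2 * math.pi * t / p for p in periods]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


# Usage: at t = 0 every sine is 0 and every cosine is 1, so
# _example_sinusoidal_embedding(0.0, 4, 1.0, 4.0) == [0.0, 0.0, 1.0, 1.0].

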
def sample_beta(alpha, beta, bsize, device):  # see openpi `sample_beta` (exact copy)
    alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
    beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
    dist = torch.distributions.Beta(alpha_t, beta_t)
    return dist.sample((bsize,))


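# Illustrative sketch, not used by the policy: `PI05Pytorch.sample_time` draws one time value
# per batch element from Beta(alpha, beta) and then applies the affine map
# `t = beta_sample * scale + offset`. With scale + offset == 1, every time lands in [offset, 1].
# This stdlib version uses `random.betavariate`; the helper name and the default scale/offset
# values are hypothetical (the real ones come from `PI05Config`).
def _example_sample_times(
    alpha: float,
    beta: float,
    n: int,
    scale: float = 0.999,
    offset: float = 0.001,
    seed: int = 0,
) -> list[float]:
    import random

    rng = random.Random(seed)  # seeded so repeated calls return the same times
    return [rng.betavariate(alpha, beta) * scale + offset for _ in range(n)]

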
def make_att_2d_masks(pad_masks, att_masks):  # see openpi `make_att_2d_masks` (exact copy)
    """Copied from big_vision.

    Tokens can attend to valid input tokens whose cumulative `att_masks` value is
    smaller than or equal to their own. This way `att_masks` int[B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      pad_masks: bool[B, N] true if the token is part of the input, false if padding.
      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
        it and 0 where it shares the same attention mask as the previous token.
    """
    if att_masks.ndim != 2:
        raise ValueError(f"att_masks must be 2D, got ndim={att_masks.ndim}")
    if pad_masks.ndim != 2:
        raise ValueError(f"pad_masks must be 2D, got ndim={pad_masks.ndim}")

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    return att_2d_masks & pad_2d_masks


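# Illustrative sketch, not used by the policy: a pure-Python version of `make_att_2d_masks`
# for a single sequence, handy for checking the docstring examples by hand. Query token q may
# attend key token k when cumsum(att)[k] <= cumsum(att)[q] and both tokens are real
# (non-padding). The helper name `_example_att_2d_mask` is hypothetical.
def _example_att_2d_mask(pad_mask: list[bool], att_mask: list[int]) -> list[list[bool]]:
    cumsum, total = [], 0
    for value in att_mask:
        total += value
        cumsum.append(total)
    n = len(att_mask)
    # Row index = query position, column index = key position (same layout as the torch version).
    return [
        [bool(cumsum[k] <= cumsum[q] and pad_mask[q] and pad_mask[k]) for k in range(n)]
        for q in range(n)
    ]


# Usage: the prefix-lm example [0 0 0 1 1 1] lets token 0 see only the 3-token prefix,
# while the last token sees everything.

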
def pad_vector(vector, new_dim):
    """Pad the last dimension of a vector to new_dim with zeros.

    `vector` can be (batch_size x sequence_length x features_dimension)
    or (batch_size x features_dimension). If the last dimension already has
    at least `new_dim` entries, the vector is returned unchanged (not truncated).
    """
    if vector.shape[-1] >= new_dim:
        return vector
    return F.pad(vector, (0, new_dim - vector.shape[-1]))


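# Illustrative sketch, not used by the policy: the same "pad the last dimension, never
# truncate" rule as `pad_vector`, applied to nested Python lists. The helper name
# `_example_pad_last_dim` is hypothetical.
def _example_pad_last_dim(vector: list, new_dim: int) -> list:
    if vector and isinstance(vector[0], list):
        # Recurse until we reach the innermost (feature) dimension.
        return [_example_pad_last_dim(row, new_dim) for row in vector]
    if len(vector) >= new_dim:
        return vector  # already wide enough: returned unchanged, not truncated
    return vector + [0.0] * (new_dim - len(vector))

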
def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].

    Args:
        images: Tensor of shape [b, h, w, c] or [b, c, h, w]; an unbatched 3D image is also accepted.
            Channels-last is assumed when the last dimension has size <= 4.
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor in the same channel layout as the input. A 3D input comes
        back with a leading batch dimension of size 1.
    """
    # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
    if images.shape[-1] <= 4:  # Assume channels-last format
        channels_last = True
        if images.dim() == 3:
            images = images.unsqueeze(0)  # Add batch dimension
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]
    else:
        channels_last = False
        if images.dim() == 3:
            images = images.unsqueeze(0)  # Add batch dimension

    batch_size, channels, cur_height, cur_width = images.shape

    # Calculate resize ratio
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    # Resize
    resized_images = F.interpolate(
        images,
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
    )

    # Handle dtype-specific clipping
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Calculate padding
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # Pad
    constant_value = 0 if images.dtype == torch.uint8 else -1.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    # Convert back to original format if needed
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]

    return padded_images


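# Illustrative sketch, not used by the policy: the size and padding arithmetic of
# `resize_with_pad_torch` without any tensors. The image is scaled by the ratio that makes it
# fit inside the target while keeping its aspect ratio, then centred; an odd leftover pixel
# goes to the bottom/right. The helper name `_example_resize_with_pad_geometry` is hypothetical.
def _example_resize_with_pad_geometry(cur_height: int, cur_width: int, height: int, width: int):
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    pad_top, extra_h = divmod(height - resized_height, 2)
    pad_left, extra_w = divmod(width - resized_width, 2)
    # Same order as the F.pad call in resize_with_pad_torch: (left, right, top, bottom).
    padding = (pad_left, pad_left + extra_w, pad_top, pad_top + extra_h)
    return resized_height, resized_width, padding


# Usage: a 480x640 frame resized into 320x320 becomes 240x320 with 40 rows of padding
# above and below: (240, 320, (0, 0, 40, 40)).

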
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
    layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
    """Run transformer layer `layer_idx` jointly over the PaliGemma and action-expert streams.

    Each stream in `inputs_embeds` uses its own layer norms, q/k/v/o projections and MLP, but the
    query/key/value states of both streams are concatenated along the sequence axis so that a
    single attention call (governed by `attention_mask`) spans all tokens. Returns the updated
    embeddings of each stream.
    """
    models = [paligemma.language_model, gemma_expert.model]
    query_states = []
    key_states = []
    value_states = []
    gates = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])  # noqa: PLW2901
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states.append(query_state)
        key_states.append(key_state)
        value_states.append(value_state)
    # Concatenate and process attention
    query_states = torch.cat(query_states, dim=2)
    key_states = torch.cat(key_states, dim=2)
    value_states = torch.cat(value_states, dim=2)
    dummy_tensor = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )
    batch_size = query_states.shape[0]
    scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
    # Attention computation
    att_output, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        query_states,
        key_states,
        value_states,
        attention_mask,
        scaling,
    )
    # Get head_dim from the current layer, not from the model
    head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
    # Merge the attention heads back into the feature axis; the head count is hard-coded to 8,
    # which matches both Gemma variants defined in `get_gemma_config`.
    att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
    # Process layer outputs
    outputs_embeds = []
    start_pos = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        end_pos = start_pos + hidden_states.shape[1]
        if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
            att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
        # first residual
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        after_first_residual = out_emb.clone()
        out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
        # Convert to bfloat16 if the next layer (mlp) uses bfloat16
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        # second residual
        out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
        outputs_embeds.append(out_emb)
        start_pos = end_pos
    return outputs_embeds


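# Illustrative sketch, not used by the policy: in `compute_layer_complete` the PaliGemma
# (prefix) and action-expert (suffix) tokens are projected separately, concatenated along the
# sequence axis for one joint attention call, and the attention output is then cut back into
# per-stream slices [start_pos, end_pos) before each stream's own `o_proj` and MLP. The helper
# below reproduces only that bookkeeping; its name is hypothetical.
def _example_stream_slices(seq_lens: list[int]) -> list[tuple[int, int]]:
    slices, start_pos = [], 0
    for length in seq_lens:
        end_pos = start_pos + length
        slices.append((start_pos, end_pos))
        start_pos = end_pos
    return slices


# Usage with arbitrary lengths: a 300-token prefix and a 50-token suffix give
# _example_stream_slices([300, 50]) == [(0, 300), (300, 350)].

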
class GemmaConfig:  # see openpi `gemma.py: Config`
    """Configuration for Gemma model variants."""

    def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
        self.width = width
        self.depth = depth
        self.mlp_dim = mlp_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim


def get_gemma_config(variant: str) -> GemmaConfig:  # see openpi `gemma.py: get_config`
    """Returns config for specified gemma variant."""
    if variant == "gemma_300m":
        return GemmaConfig(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    elif variant == "gemma_2b":
        return GemmaConfig(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    else:
        raise ValueError(f"Unknown variant: {variant}")


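# Illustrative sketch, not used by the policy: a rough count of the weights in the transformer
# blocks described by a `GemmaConfig`, ignoring token embeddings, norms and biases. It assumes
# separate q/k/v/o projections and a gated MLP with three width x mlp_dim matrices
# (gate, up, down), as in Gemma. The helper name `_example_block_param_count` is hypothetical.
def _example_block_param_count(
    width: int, depth: int, mlp_dim: int, num_heads: int, num_kv_heads: int, head_dim: int
) -> int:
    q_and_o = 2 * width * num_heads * head_dim  # query and output projections
    k_and_v = 2 * width * num_kv_heads * head_dim  # key and value projections (grouped KV heads)
    mlp = 3 * width * mlp_dim  # gate, up and down projections
    return depth * (q_and_o + k_and_v + mlp)


# Usage: the `gemma_300m` settings give 311,427,072 block weights (about 0.31B); the
# `gemma_2b` settings give 1,981,808,640, before adding the 257,152 x 2048 embedding table.

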
# See openpi `gemma_pytorch.py: PaliGemmaWithExpertModel`; this class is almost an exact copy of it.
class PaliGemmaWithExpertModel(nn.Module):
    """PaliGemma model with action expert for PI05."""

    def __init__(
        self,
        vlm_config,
        action_expert_config,
        use_adarms=None,
        precision: Literal["bfloat16", "float32"] = "bfloat16",
    ):
        if use_adarms is None:
            use_adarms = [False, False]
        super().__init__()

        vlm_config_hf = CONFIG_MAPPING["paligemma"]()
        vlm_config_hf._vocab_size = 257152  # noqa: SLF001
        vlm_config_hf.image_token_index = 257152
        vlm_config_hf.text_config.hidden_size = vlm_config.width
        vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
        vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
        vlm_config_hf.text_config.head_dim = vlm_config.head_dim
        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
        vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
        vlm_config_hf.text_config.torch_dtype = "float32"
        vlm_config_hf.text_config.vocab_size = 257152
        vlm_config_hf.text_config.use_adarms = use_adarms[0]
        vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
        vlm_config_hf.vision_config.intermediate_size = 4304
        vlm_config_hf.vision_config.projection_dim = 2048
        vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
        vlm_config_hf.vision_config.torch_dtype = "float32"

        action_expert_config_hf = CONFIG_MAPPING["gemma"](
            head_dim=action_expert_config.head_dim,
            hidden_size=action_expert_config.width,
            intermediate_size=action_expert_config.mlp_dim,
            num_attention_heads=action_expert_config.num_heads,
            num_hidden_layers=action_expert_config.depth,
            num_key_value_heads=action_expert_config.num_kv_heads,
            vocab_size=257152,
            hidden_activation="gelu_pytorch_tanh",
            torch_dtype="float32",
            use_adarms=use_adarms[1],
            adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
        )

        self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
        self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_for_selected_params(precision)

    def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
        """Cast the model to `precision`.

        With bfloat16, parameters whose names contain any of the selectors below (the SigLIP patch
        and position embeddings and the normalization layers) are kept in float32.
        """
        if precision == "bfloat16":
            self.to(dtype=torch.bfloat16)
        elif precision == "float32":
            self.to(dtype=torch.float32)
            return
        else:
            raise ValueError(f"Invalid precision: {precision}")

        params_to_keep_float32 = [
            "vision_tower.vision_model.embeddings.patch_embedding.weight",
            "vision_tower.vision_model.embeddings.patch_embedding.bias",
            "vision_tower.vision_model.embeddings.position_embedding.weight",
            "input_layernorm",
            "post_attention_layernorm",
            "model.norm",
        ]

        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_keep_float32):
                param.data = param.data.to(dtype=torch.float32)

    def embed_image(self, image: torch.Tensor):
        return self.paligemma.model.get_image_features(image)

    def embed_language_tokens(self, tokens: torch.Tensor):
        return self.paligemma.language_model.embed_tokens(tokens)

    def forward(
        self,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: list[torch.FloatTensor] | None = None,
        use_cache: bool | None = None,
        adarms_cond: list[torch.Tensor] | None = None,
    ):
        """Run the PaliGemma language model and/or the action expert.

        `inputs_embeds` is `[prefix_embeds, suffix_embeds]`. If the suffix entry is None, only the
        PaliGemma language model runs and its `past_key_values` are returned; if the prefix entry
        is None, only the action expert runs (on top of any `past_key_values` passed in);
        otherwise both streams are processed layer by layer with joint attention via
        `compute_layer_complete`. Returns `([prefix_output, suffix_output], prefix_past_key_values)`.
        """
        if adarms_cond is None:
            adarms_cond = [None, None]
        if inputs_embeds[1] is None:
            prefix_output = self.paligemma.language_model.forward(
                inputs_embeds=inputs_embeds[0],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
            )
            prefix_past_key_values = prefix_output.past_key_values
            prefix_output = prefix_output.last_hidden_state
            suffix_output = None
        elif inputs_embeds[0] is None:
            suffix_output = self.gemma_expert.model.forward(
                inputs_embeds=inputs_embeds[1],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
            )
            suffix_output = suffix_output.last_hidden_state
            prefix_output = None
            prefix_past_key_values = None
        else:
            models = [self.paligemma.language_model, self.gemma_expert.model]
            num_layers = self.paligemma.config.text_config.num_hidden_layers

            # Check if gradient checkpointing is enabled for any of the models
            use_gradient_checkpointing = (
                hasattr(self.gemma_expert.model, "gradient_checkpointing")
                and self.gemma_expert.model.gradient_checkpointing
                and self.training
            ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)

            # Process all layers with gradient checkpointing if enabled
            for layer_idx in range(num_layers):
                if use_gradient_checkpointing:
                    inputs_embeds = torch.utils.checkpoint.checkpoint(
                        compute_layer_complete,
                        layer_idx,
                        inputs_embeds,
                        attention_mask,
                        position_ids,
                        adarms_cond,
                        use_reentrant=False,
                        preserve_rng_state=False,
                        paligemma=self.paligemma,
                        gemma_expert=self.gemma_expert,
                    )
                else:
                    inputs_embeds = compute_layer_complete(
                        layer_idx,
                        inputs_embeds,
                        attention_mask,
                        position_ids,
                        adarms_cond,
                        paligemma=self.paligemma,
                        gemma_expert=self.gemma_expert,
                    )

            # final norm
            def compute_final_norms(inputs_embeds, adarms_cond):
                outputs_embeds = []
                for i, hidden_states in enumerate(inputs_embeds):
                    out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
                    outputs_embeds.append(out_emb)
                return outputs_embeds

            # Apply gradient checkpointing to final norm if enabled
            if use_gradient_checkpointing:
                outputs_embeds = torch.utils.checkpoint.checkpoint(
                    compute_final_norms,
                    inputs_embeds,
                    adarms_cond,
                    use_reentrant=False,
                    preserve_rng_state=False,
                )
            else:
                outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)

            prefix_output = outputs_embeds[0]
            suffix_output = outputs_embeds[1]
            prefix_past_key_values = None

        return [prefix_output, suffix_output], prefix_past_key_values


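# Illustrative sketch, not used by the policy: `to_bfloat16_for_selected_params` keeps a
# parameter in float32 when any selector is a *substring* of its name, so "model.norm" also
# matches names such as "...language_model.norm.weight". The selector tuple mirrors the list in
# that method; the helper name and the parameter names in the usage note are hypothetical.
_EXAMPLE_KEEP_FLOAT32_SELECTORS = (
    "vision_tower.vision_model.embeddings.patch_embedding.weight",
    "vision_tower.vision_model.embeddings.patch_embedding.bias",
    "vision_tower.vision_model.embeddings.position_embedding.weight",
    "input_layernorm",
    "post_attention_layernorm",
    "model.norm",
)


def _example_keeps_float32(param_name: str, selectors=_EXAMPLE_KEEP_FLOAT32_SELECTORS) -> bool:
    # Plain substring matching, exactly like `any(selector in name ...)` in the method.
    return any(selector in param_name for selector in selectors)


# Usage: "paligemma.model.language_model.layers.0.input_layernorm.weight" stays in float32,
# while "paligemma.model.language_model.layers.0.self_attn.q_proj.weight" is cast to bfloat16.

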
class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
    """Core PI05 PyTorch model."""

    def __init__(self, config: PI05Config):
        super().__init__()
        self.config = config

        paligemma_config = get_gemma_config(config.paligemma_variant)
        action_expert_config = get_gemma_config(config.action_expert_variant)

        self.paligemma_with_expert = PaliGemmaWithExpertModel(
            paligemma_config,
            action_expert_config,
            use_adarms=[False, True],
            precision=config.dtype,
        )

        self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
        self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim)

        self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
        self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)

        # Initialize gradient checkpointing flag
        self.gradient_checkpointing_enabled = False

        # Compile model if requested
        if config.compile_model:
            torch.set_float32_matmul_precision("high")
            self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)

        msg = (
            "An incorrect `transformers` version is installed; please create an issue at "
            "https://github.com/huggingface/lerobot/issues"
        )

        try:
            from transformers.models.siglip import check

            if not check.check_whether_transformers_replace_is_installed_correctly():
                raise ValueError(msg)
        except ImportError:
            raise ValueError(msg) from None

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory optimization."""
        self.gradient_checkpointing_enabled = True
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
        logging.info("Enabled gradient checkpointing for PI05Pytorch model")

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing_enabled = False
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
        logging.info("Disabled gradient checkpointing for PI05Pytorch model")

    def _apply_checkpoint(self, func, *args, **kwargs):
        """Helper method to apply gradient checkpointing if enabled."""
        if self.gradient_checkpointing_enabled and self.training:
            return torch.utils.checkpoint.checkpoint(
                func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
            )
        return func(*args, **kwargs)

    def _prepare_attention_masks_4d(self, att_2d_masks):
        """Convert a boolean [B, Q, K] mask into an additive [B, 1, Q, K] float mask.

        Allowed positions become 0.0 and blocked positions become OPENPI_ATTENTION_MASK_VALUE.
        """
        att_2d_masks_4d = att_2d_masks[:, None, :, :]
        return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)

    def sample_noise(self, shape, device):
        return torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )

    def sample_time(self, bsize, device):
        time_beta = sample_beta(
            self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
        )
        time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
        return time.to(dtype=torch.float32, device=device)

    def embed_prefix(
        self, images, img_masks, tokens, masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with embedding layer."""
        embs = []
        pad_masks = []
        att_masks = []

        # Process images
        for img, img_mask in zip(images, img_masks, strict=True):

            def image_embed_func(img):
                return self.paligemma_with_expert.embed_image(img)

            img_emb = self._apply_checkpoint(image_embed_func, img)
            bsize, num_img_embs = img_emb.shape[:2]

            embs.append(img_emb)
            pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
            att_masks += [0] * num_img_embs

        # Process language tokens
        def lang_embed_func(tokens):
            lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
            lang_emb_dim = lang_emb.shape[-1]
            return lang_emb * math.sqrt(lang_emb_dim)

        lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
        embs.append(lang_emb)
        pad_masks.append(masks)

        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)

        bsize = pad_masks.shape[0]
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks

    def embed_suffix(self, noisy_actions, timestep):
        """Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
        embs = []
        pad_masks = []
        att_masks = []

        # Embed timestep using sine-cosine positional encoding
        time_emb = create_sinusoidal_pos_embedding(
            timestep,
            self.action_in_proj.out_features,
            min_period=self.config.min_period,
            max_period=self.config.max_period,
            device=timestep.device,
        )
        time_emb = time_emb.type(dtype=timestep.dtype)

        # Project the noisy actions to the action expert's width
        def action_proj_func(noisy_actions):
            return self.action_in_proj(noisy_actions)

        action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)

        # Pass the timestep embedding through an MLP. In PI05 it is not fused into the action
        # tokens; it is returned as `adarms_cond` to condition the action expert (adaRMS)
        def time_mlp_func(time_emb):
            x = self.time_mlp_in(time_emb)
            x = F.silu(x)
            x = self.time_mlp_out(x)
            return F.silu(x)

        time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
        action_time_emb = action_emb
        adarms_cond = time_emb

        embs.append(action_time_emb)
        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
        pad_masks.append(action_time_mask)

        # Set attention masks so that image, language and state inputs do not attend to action tokens
        att_masks += [1] + ([0] * (self.config.chunk_size - 1))

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks, adarms_cond

    def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
        """Do a full training forward pass and compute the loss."""
        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)

        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)

        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)

        if (
            self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
            == torch.bfloat16
        ):
            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
            prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)

        def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
            (_, suffix_out), _ = self.paligemma_with_expert.forward(
                attention_mask=att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, suffix_embs],
                use_cache=False,
                adarms_cond=[None, adarms_cond],
            )
            return suffix_out

        suffix_out = self._apply_checkpoint(
            forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
        )

        suffix_out = suffix_out[:, -self.config.chunk_size :]
        suffix_out = suffix_out.to(dtype=torch.float32)

        def action_out_proj_func(suffix_out):
            return self.action_out_proj(suffix_out)

        v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)

        return F.mse_loss(u_t, v_t, reduction="none")

    @torch.no_grad()  # see openpi `sample_actions` (slightly adapted)
    def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
        """Do a full inference forward and compute the action."""
        if num_steps is None:
            num_steps = self.config.num_inference_steps

        bsize = tokens.shape[0]
        device = tokens.device

        if noise is None:
            # Sample noise with padded dimension as expected by action_in_proj
            actions_shape = (
                bsize,
                self.config.chunk_size,
                self.config.max_action_dim,
            )  # Use config max_action_dim for internal processing
            noise = self.sample_noise(actions_shape, device)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001

        # Run the prefix (images + language) once and cache its keys/values;
        # each denoising step below only processes the action suffix
        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks_4d,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=True,
        )

        dt = -1.0 / num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

        # Euler integration of the predicted velocity field from t=1 (pure noise) to t=0 (actions)
        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)
            v_t = self.denoise_step(
                prefix_pad_masks,
                past_key_values,
                x_t,
                expanded_time,
            )
            x_t = x_t + dt * v_t
            time += dt

        return x_t

    def denoise_step(
        self,
        prefix_pad_masks,
        past_key_values,
        x_t,
        timestep,
    ):
        """Predict the velocity `v_t` for the noisy actions `x_t` at the given timestep (one denoising step)."""
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)

        suffix_len = suffix_pad_masks.shape[1]
        batch_size = prefix_pad_masks.shape[0]
        prefix_len = prefix_pad_masks.shape[1]

        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)

        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
        self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks_4d,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=False,
            adarms_cond=[None, adarms_cond],
        )

        suffix_out = outputs_embeds[1]
        suffix_out = suffix_out[:, -self.config.chunk_size :]
        suffix_out = suffix_out.to(dtype=torch.float32)
        return self.action_out_proj(suffix_out)


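# Illustrative sketch (hypothetical helpers, not used by the policy): a 1-D toy version of the
# flow-matching convention in PI05Pytorch above. forward() trains on
# x_t = t * noise + (1 - t) * actions with target velocity u_t = noise - actions, and
# sample_actions() integrates from t=1 (noise) to t=0 (actions) with Euler steps of
# dt = -1 / num_steps, stopping once t < -dt / 2.
# Assumption: the exact target velocity stands in for the learned action expert, so Euler
# integration along this straight-line path recovers the action up to float rounding.


def _toy_flow_matching_pair(noise: float, action: float, t: float) -> tuple[float, float]:
    """Return (x_t, u_t) for scalars, mirroring the interpolation in PI05Pytorch.forward."""
    x_t = t * noise + (1 - t) * action
    u_t = noise - action
    return x_t, u_t


def _toy_euler_sample(noise: float, action: float, num_steps: int = 10) -> float:
    """Euler-integrate the toy velocity field from t=1 to t=0, as PI05Pytorch.sample_actions does."""
    dt = -1.0 / num_steps
    x_t, t = noise, 1.0
    while t >= -dt / 2:  # same stopping rule as sample_actions
        # Hypothetical oracle: the exact velocity u_t replaces the model's prediction v_t
        _, v_t = _toy_flow_matching_pair(noise, action, t)
        x_t = x_t + dt * v_t
        t += dt
    return x_t


# Usage: _toy_flow_matching_pair(2.0, -0.5, 1.0) gives (2.0, 2.5), i.e. pure noise at t=1,
# and _toy_euler_sample(2.0, -0.5) ends within float rounding of the action -0.5.
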
class PI05Policy(PreTrainedPolicy):
    """PI05 Policy for LeRobot."""

    config_class = PI05Config
    name = "pi05"

    def __init__(
        self,
        config: PI05Config,
    ):
        """
        Args:
            config: Policy configuration class instance.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        # Initialize the core PI05 model
        self.model = PI05Pytorch(config)

        # Enable gradient checkpointing if requested
        if config.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        self.model.to(config.device)

        self.reset()

    @classmethod
    def from_pretrained(
        cls: builtins.type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool = True,
        **kwargs,
    ) -> T:
        """Override the from_pretrained method to handle key remapping and display important disclaimer."""
        print(
            "The PI05 model is a direct port of the OpenPI implementation. \n"
            "This implementation follows the original OpenPI structure for compatibility. \n"
            "Original implementation: https://github.com/Physical-Intelligence/openpi"
        )
        if pretrained_name_or_path is None:
            raise ValueError("pretrained_name_or_path is required")

        # Use the provided config if available, otherwise load it from pretrained_name_or_path
        if config is None:
            config = PreTrainedConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )

        # Initialize the model with fresh weights; pretrained weights are loaded below
        model = cls(config, **kwargs)

        # Now manually load and remap the state dict
        try:
            # Load model.safetensors from a local directory or a Hub repo
            print(f"Loading model from: {pretrained_name_or_path}")
            try:
                from transformers.utils import cached_file

                # Forward the download options this method received (they are named
                # parameters, so they are not present in **kwargs)
                resolved_file = cached_file(
                    pretrained_name_or_path,
                    "model.safetensors",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                )
                from safetensors.torch import load_file

                original_state_dict = load_file(resolved_file)
                print("✓ Loaded state dict from model.safetensors")
            except Exception as e:
                print(f"Could not load state dict from remote files: {e}")
                print("Returning model without loading pretrained weights")
                return model

            # First, fix any key differences  # see openpi `model.py, _fix_pytorch_state_dict_keys`
            fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

            # Then add "model." prefix for all keys that don't already have it
            remapped_state_dict = {}
            remap_count = 0

            for key, value in fixed_state_dict.items():
                if not key.startswith("model."):
                    new_key = f"model.{key}"
                    remapped_state_dict[new_key] = value
                    remap_count += 1
                    if remap_count <= 10:  # Only print first 10 to avoid spam
                        print(f"Remapped: {key} -> {new_key}")
                else:
                    remapped_state_dict[key] = value

            if remap_count > 0:
                print(f"Remapped {remap_count} state dict keys")

            # Load the remapped state dict into the model
            missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)

            if missing_keys:
                print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
                if len(missing_keys) <= 5:
                    for key in missing_keys:
                        print(f" - {key}")
                else:
                    for key in missing_keys[:5]:
                        print(f" - {key}")
                    print(f" ... and {len(missing_keys) - 5} more")

            if unexpected_keys:
                print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
                if len(unexpected_keys) <= 5:
                    for key in unexpected_keys:
                        print(f" - {key}")
                else:
                    for key in unexpected_keys[:5]:
                        print(f" - {key}")
                    print(f" ... and {len(unexpected_keys) - 5} more")

            if not missing_keys and not unexpected_keys:
                print("All keys loaded successfully!")

        except Exception as e:
            print(f"Warning: Could not remap state dict keys: {e}")

        return model

    def _fix_pytorch_state_dict_keys(
        self, state_dict, model_config
    ):  # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
        """Fix state dict keys to match current model architecture."""
        import re

        fixed_state_dict = {}

        for key, value in state_dict.items():
            new_key = key

            # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
            # For gemma expert layers
            if re.match(
                r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
                key,
            ):
                # Check if the model actually has adaRMS enabled for the expert
                expert_uses_adarms = getattr(
                    self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
                )
                if expert_uses_adarms:
                    logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
                    continue

            if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
                # Check if the model actually has adaRMS enabled for the expert
                expert_uses_adarms = getattr(
                    self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
                )
                if expert_uses_adarms:
                    logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
                    continue

            # Handle MLP naming changes for pi05
            # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_*
            if key.startswith("action_time_mlp_in."):
                new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
            elif key.startswith("action_time_mlp_out."):
                new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")
            # Also handle state_proj which shouldn't exist in pi05
            if key.startswith("state_proj."):
                logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
                continue

            # Handle vision tower embedding layer potential differences
            if "patch_embedding" in key:
                # Some checkpoints might have this, but current model expects different structure
                logging.warning(f"Vision embedding key might need handling: {key}")

            fixed_state_dict[new_key] = value

        return fixed_state_dict

    def get_optim_params(self) -> dict:
        return self.parameters()

    def reset(self):
        """Reset internal state - called when environment resets."""
        self._action_queue = deque(maxlen=self.config.n_action_steps)
        self._queues = {
            ACTION: deque(maxlen=self.config.n_action_steps),
        }

    def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
        """Preprocess images for the model.

        Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
        PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
        """
        images = []
        img_masks = []

        # Get device from model parameters
        device = next(self.parameters()).device

        present_img_keys = [key for key in self.config.image_features if key in batch]
        missing_img_keys = [key for key in self.config.image_features if key not in batch]

        if len(present_img_keys) == 0:
            raise ValueError(
                f"All image features are missing from the batch. At least one expected. "
                f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
            )

        # Preprocess image features present in the batch
        for key in present_img_keys:
            img = batch[key]

            # Ensure tensor is on the same device as the model
            if img.device != device:
                img = img.to(device)

            # Ensure float32 dtype for consistency
            if img.dtype != torch.float32:
                img = img.to(torch.float32)

            # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
            is_channels_first = img.shape[1] == 3  # Check if channels are in dimension 1

            if is_channels_first:
                # Convert [B, C, H, W] to [B, H, W, C] for processing
                img = img.permute(0, 2, 3, 1)

            # from openpi preprocess_observation_pytorch: Resize with padding if needed
            if img.shape[1:3] != self.config.image_resolution:
                img = resize_with_pad_torch(img, *self.config.image_resolution)

            # Normalize from [0,1] to [-1,1] as expected by siglip
            img = img * 2.0 - 1.0

            # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
            if is_channels_first:
                img = img.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

            images.append(img)
            # Create mask (all ones for real images)
            bsize = img.shape[0]
            mask = torch.ones(bsize, dtype=torch.bool, device=device)
            img_masks.append(mask)

        # For each missing camera, add a placeholder image filled with -1 (SigLIP's padding value)
        # together with an all-False mask so the model ignores it
        for _num_empty_cameras in range(len(missing_img_keys)):
            img = torch.ones_like(img) * -1  # Padded with -1 for SigLIP
            mask = torch.zeros_like(mask)  # Mask is zero for empty cameras
            images.append(img)
            img_masks.append(mask)

        return images, img_masks

    def prepare_action(self, batch):
        """Pad actions along the last dimension to `config.max_action_dim`."""
        actions = pad_vector(batch[ACTION], self.config.max_action_dim)
        return actions

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations."""
        self.eval()

        # Action queue logic for n_action_steps > 1
        if len(self._action_queue) == 0:
            actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
            # Transpose to get shape (n_action_steps, batch_size, action_dim)
            self._action_queue.extend(actions.transpose(0, 1))

        return self._action_queue.popleft()

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict a chunk of actions given environment observations."""
        self.eval()

        # Prepare inputs
        images, img_masks = self._preprocess_images(batch)
        tokens, masks = batch[OBS_LANGUAGE_TOKENS], batch[OBS_LANGUAGE_ATTENTION_MASK]

        # Sample actions using the model (no separate state needed for PI05)
        actions = self.model.sample_actions(images, img_masks, tokens, masks)

        # Unpad actions to actual action dimension
        original_action_dim = self.config.output_features[ACTION].shape[0]
        actions = actions[:, :, :original_action_dim]

        return actions

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training."""
        # Prepare inputs
        images, img_masks = self._preprocess_images(batch)
        tokens, masks = batch[OBS_LANGUAGE_TOKENS], batch[OBS_LANGUAGE_ATTENTION_MASK]

        actions = self.prepare_action(batch)

        # Compute loss (no separate state needed for PI05)
        losses = self.model.forward(images, img_masks, tokens, masks, actions)

        # Truncate losses to actual action dimensions
        original_action_dim = self.config.output_features[ACTION].shape[0]
        losses = losses[:, :, :original_action_dim]

        loss = losses.mean()

        loss_dict = {
            "loss": loss.item(),
            "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
        }

        return loss, loss_dict
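

# Illustrative sketch (hypothetical class, not part of the LeRobot API): the chunk-queue pattern
# of PI05Policy.select_action above, using plain lists instead of tensors. A predicted chunk of
# shape [batch][chunk_size] is truncated to n_action_steps, transposed to time-major order
# [n_action_steps][batch], and served one step per call. A new chunk is predicted only when the
# queue is empty. `predict_chunk` is an assumed stand-in for predict_action_chunk.
from collections import deque
from collections.abc import Callable


class _ToyActionQueue:
    def __init__(self, predict_chunk: Callable[[], list[list[float]]], n_action_steps: int):
        self._predict_chunk = predict_chunk
        self._n_action_steps = n_action_steps
        self._queue: deque = deque(maxlen=n_action_steps)
        self.num_chunk_predictions = 0

    def select_action(self) -> list[float]:
        if not self._queue:
            # Keep only the first n_action_steps of each batch element's chunk
            chunk = [row[: self._n_action_steps] for row in self._predict_chunk()]
            self.num_chunk_predictions += 1
            # Transpose [batch][steps] -> [steps][batch], like actions.transpose(0, 1)
            self._queue.extend(list(step) for step in zip(*chunk))
        return self._queue.popleft()


# Usage: with n_action_steps=3 and a 5-step chunk, the first three calls return steps 0..2 for
# every batch element, and the fourth call triggers a second chunk prediction.
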