mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
1214 lines
49 KiB
Python
1214 lines
49 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import logging
|
|
import math
|
|
from collections import deque
|
|
from typing import Literal
|
|
|
|
import torch
|
|
import torch.nn.functional as F # noqa: N812
|
|
from torch import Tensor, nn
|
|
from transformers import AutoTokenizer
|
|
from transformers.models.auto import CONFIG_MAPPING
|
|
from transformers.models.gemma import modeling_gemma
|
|
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
|
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
|
|
|
from lerobot.constants import ACTION, OBS_STATE
|
|
from lerobot.policies.normalize import Normalize, Unnormalize
|
|
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
|
|
|
|
|
# Helper functions
|
|
def get_safe_dtype(target_dtype, device_type):  # see openpi `get_safe_dtype` (adapted: MPS fallback added)
    """Return a dtype that can be used on the given device type.

    Args:
        target_dtype: The ``torch.dtype`` the caller would like to use.
        device_type: Device type string as found in ``torch.device.type``
            (e.g. ``"cpu"``, ``"cuda"``, ``"mps"``).

    Returns:
        ``torch.float32`` when bfloat16 is requested on CPU or float64 is
        requested on MPS; ``target_dtype`` unchanged in every other case.
    """
    # Behaviour kept from openpi: bfloat16 is replaced by float32 on CPU.
    if device_type == "cpu" and target_dtype == torch.bfloat16:
        return torch.float32
    # The MPS backend cannot hold float64 tensors, so fall back to float32
    # instead of letting tensor creation fail on Apple-silicon devices.
    if device_type == "mps" and target_dtype == torch.float64:
        return torch.float32
    return target_dtype
|
|
|
|
|
|
def create_sinusoidal_pos_embedding(  # see openpi `create_sinusoidal_pos_embedding`
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions.

    Args:
        time: 1-D tensor of scalar positions, one per batch element.
        dimension: Size of the returned embedding; must be even.
        min_period: Shortest period of the sinusoid bank.
        max_period: Longest period of the sinusoid bank.
        device: Device on which the frequency table is created. Accepts either
            a ``torch.device`` or a device string such as ``"cpu"``.

    Returns:
        Tensor of shape ``(len(time), dimension)``: the first half of the last
        axis holds the sines, the second half the cosines.

    Raises:
        ValueError: If ``dimension`` is odd or ``time`` is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # Normalise first: the default value (and any caller passing a string) is a
    # plain `str`, which has no `.type` attribute.
    device = torch.device(device)

    # Periods are requested in float64; `get_safe_dtype` may substitute another
    # dtype depending on the device type.
    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    # Geometric interpolation between min_period and max_period.
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
|
|
|
|
|
def sample_beta(alpha, beta, bsize, device):  # see openpi `sample_beta`
    """Draw ``bsize`` samples from a Beta(alpha, beta) distribution.

    Both concentration parameters are converted to float32 tensors on
    ``device`` before the distribution is built.

    Returns:
        Tensor of shape ``(bsize,)`` when ``alpha`` and ``beta`` are scalars.
    """
    concentrations = [
        torch.as_tensor(value, dtype=torch.float32, device=device) for value in (alpha, beta)
    ]
    return torch.distributions.Beta(*concentrations).sample((bsize,))
|
|
|
|
|
|
def make_att_2d_masks(pad_masks, att_masks):  # see openpi `make_att_2d_masks`
    """Build a [B, N, N] attention mask from padding and block markers.

    Adapted from big_vision. A query token may attend to a key token when the
    key's cumulative ``att_masks`` value is smaller than or equal to the
    query's, and both tokens are valid according to ``pad_masks``.

    Examples of ``att_masks`` rows:

      [1 1 1 1 1 1]: pure causal attention.

      [0 0 0 1 1 1]: prefix-lm attention. The first 3 tokens attend to each
          other and the last 3 tokens attend causally. The first entry could
          also be a 1 without changing behaviour.

      [1 0 1 0 1 0 0 1 0 0]: causal attention between 4 blocks. Tokens of a
          block attend to all previous blocks and to every token of their own
          block.

    Args:
        pad_masks: [B, N] mask, true where the token is real input and false
            where it is padding.
        att_masks: [B, N] mask that is 1 where previous tokens cannot depend
            on the token and 0 where the token shares the attention pattern of
            the previous token.

    Returns:
        [B, N, N] mask indexed as ``[batch, query, key]``.

    Raises:
        ValueError: If either argument is not 2-D (the offending ``ndim`` is
            the exception argument).
    """
    # att_masks is validated first, as before.
    for mask in (att_masks, pad_masks):
        if mask.ndim != 2:
            raise ValueError(mask.ndim)

    block_ids = torch.cumsum(att_masks, dim=1)
    # Key block (axis 2) must not come after the query block (axis 1).
    may_attend = block_ids.unsqueeze(1) <= block_ids.unsqueeze(2)
    # Padding tokens neither attend nor get attended to.
    both_valid = pad_masks.unsqueeze(1) * pad_masks.unsqueeze(2)
    return may_attend & both_valid
|
|
|
|
|
|
def resize_with_pad_torch(  # see openpi `resize_with_pad_torch`
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without
    distortion by padding with black. If the image is float32, it must be in the range [-1, 1].

    The layout is guessed from the last axis: a trailing size of 4 or less is
    treated as channels-last, anything else as channels-first.

    Args:
        images: Tensor of shape [h, w, c] / [b, h, w, c] (channels-last) or
            [c, h, w] / [b, c, h, w] (channels-first); uint8 or float32.
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor with the same layout and the same number of
        dimensions as the input. A batch axis is only dropped if this function
        added it, so a real batch of size 1 stays 4-D.

    Raises:
        ValueError: If the image dtype is neither uint8 nor float32.
    """
    # Remember whether the batch axis is ours: batch size alone cannot tell an
    # unbatched image apart from a genuine batch containing one image.
    added_batch_dim = images.dim() == 3
    if added_batch_dim:
        images = images.unsqueeze(0)  # Add batch dimension

    # Check if input is in channels-last format [b, h, w, c] or channels-first [b, c, h, w]
    channels_last = images.shape[-1] <= 4  # Assume channels-last format
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]

    _, _, cur_height, cur_width = images.shape

    # Calculate resize ratio: the more constrained axis decides the scale so
    # the resized image fits inside the target box.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    # Resize
    resized_images = F.interpolate(
        images,
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
    )

    # Handle dtype-specific clipping
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Calculate padding; an odd remainder goes to the bottom/right side.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # Pad with "black": 0 for uint8 images, -1.0 for float images in [-1, 1].
    constant_value = 0 if images.dtype == torch.uint8 else -1.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    # Convert back to original format if needed
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
    if added_batch_dim:
        padded_images = padded_images.squeeze(0)  # Remove the batch dimension we added

    return padded_images
|
|
|
|
|
|
class GemmaConfig:  # see openpi `gemma.py: Config`
    """Plain holder for the size hyper-parameters of one Gemma variant.

    Attributes are set exactly as passed, without validation: ``width``,
    ``depth``, ``mlp_dim``, ``num_heads``, ``num_kv_heads`` and ``head_dim``.
    """

    def __init__(
        self,
        width: int,
        depth: int,
        mlp_dim: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
    ):
        # Trunk dimensions.
        self.width, self.depth, self.mlp_dim = width, depth, mlp_dim
        # Attention layout.
        self.num_heads, self.num_kv_heads, self.head_dim = num_heads, num_kv_heads, head_dim
|
|
|
|
|
|
def get_gemma_config(variant: str) -> GemmaConfig:  # see openpi `gemma.py: get_config`
    """Return a freshly built ``GemmaConfig`` for a named Gemma variant.

    Supported names are ``"gemma_300m"`` and ``"gemma_2b"``. They differ only
    in ``width`` and ``mlp_dim``; depth and attention layout are shared.

    Raises:
        ValueError: If ``variant`` is not one of the supported names.
    """
    # (name, width, mlp_dim) for every supported variant.
    known_variants = (
        ("gemma_300m", 1024, 4096),
        ("gemma_2b", 2048, 16_384),
    )
    for name, width, mlp_dim in known_variants:
        if variant == name:
            return GemmaConfig(
                width=width,
                depth=18,
                mlp_dim=mlp_dim,
                num_heads=8,
                num_kv_heads=1,
                head_dim=256,
            )
    raise ValueError(f"Unknown variant: {variant}")
|
|
|
|
|
|
class PaliGemmaWithExpertModel(
|
|
nn.Module
|
|
): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi
|
|
"""PaliGemma model with action expert for PI0."""
|
|
|
|
def __init__(
|
|
self,
|
|
vlm_config,
|
|
action_expert_config,
|
|
use_adarms=None,
|
|
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
|
):
|
|
if use_adarms is None:
|
|
use_adarms = [False, False]
|
|
super().__init__()
|
|
|
|
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
|
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
|
vlm_config_hf.image_token_index = 257152
|
|
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
|
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
|
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
|
|
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
|
|
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
|
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
|
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
|
vlm_config_hf.text_config.torch_dtype = "float32"
|
|
vlm_config_hf.text_config.vocab_size = 257152
|
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
|
vlm_config_hf.vision_config.projection_dim = 2048
|
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
|
vlm_config_hf.vision_config.torch_dtype = "float32"
|
|
|
|
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
|
head_dim=action_expert_config.head_dim,
|
|
hidden_size=action_expert_config.width,
|
|
intermediate_size=action_expert_config.mlp_dim,
|
|
num_attention_heads=action_expert_config.num_heads,
|
|
num_hidden_layers=action_expert_config.depth,
|
|
num_key_value_heads=action_expert_config.num_kv_heads,
|
|
vocab_size=257152,
|
|
hidden_activation="gelu_pytorch_tanh",
|
|
torch_dtype="float32",
|
|
use_adarms=use_adarms[1],
|
|
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
|
)
|
|
|
|
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
|
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
|
self.gemma_expert.model.embed_tokens = None
|
|
|
|
self.to_bfloat16_for_selected_params(precision)
|
|
|
|
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
|
if precision == "bfloat16":
|
|
self.to(dtype=torch.bfloat16)
|
|
elif precision == "float32":
|
|
self.to(dtype=torch.float32)
|
|
return
|
|
else:
|
|
raise ValueError(f"Invalid precision: {precision}")
|
|
|
|
params_to_keep_float32 = [
|
|
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
|
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
|
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
|
"input_layernorm",
|
|
"post_attention_layernorm",
|
|
"model.norm",
|
|
]
|
|
|
|
for name, param in self.named_parameters():
|
|
if any(selector in name for selector in params_to_keep_float32):
|
|
param.data = param.data.to(dtype=torch.float32)
|
|
|
|
def embed_image(self, image: torch.Tensor):
|
|
return self.paligemma.model.get_image_features(image)
|
|
|
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
|
return self.paligemma.language_model.embed_tokens(tokens)
|
|
|
|
def forward(
|
|
self,
|
|
attention_mask: torch.Tensor | None = None,
|
|
position_ids: torch.LongTensor | None = None,
|
|
past_key_values: list[torch.FloatTensor] | None = None,
|
|
inputs_embeds: list[torch.FloatTensor] | None = None,
|
|
use_cache: bool | None = None,
|
|
adarms_cond: list[torch.Tensor] | None = None,
|
|
):
|
|
if adarms_cond is None:
|
|
adarms_cond = [None, None]
|
|
if inputs_embeds[1] is None:
|
|
prefix_output = self.paligemma.language_model.forward(
|
|
inputs_embeds=inputs_embeds[0],
|
|
attention_mask=attention_mask,
|
|
position_ids=position_ids,
|
|
past_key_values=past_key_values,
|
|
use_cache=use_cache,
|
|
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
|
)
|
|
prefix_past_key_values = prefix_output.past_key_values
|
|
prefix_output = prefix_output.last_hidden_state
|
|
suffix_output = None
|
|
elif inputs_embeds[0] is None:
|
|
suffix_output = self.gemma_expert.model.forward(
|
|
inputs_embeds=inputs_embeds[1],
|
|
attention_mask=attention_mask,
|
|
position_ids=position_ids,
|
|
past_key_values=past_key_values,
|
|
use_cache=use_cache,
|
|
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
|
)
|
|
suffix_output = suffix_output.last_hidden_state
|
|
prefix_output = None
|
|
prefix_past_key_values = None
|
|
else:
|
|
models = [self.paligemma.language_model, self.gemma_expert.model]
|
|
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
|
|
|
# Check if gradient checkpointing is enabled for any of the models
|
|
use_gradient_checkpointing = (
|
|
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
|
and self.gemma_expert.model.gradient_checkpointing
|
|
and self.training
|
|
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
|
|
|
# Define the complete layer computation function for gradient checkpointing
|
|
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
|
|
models = [self.paligemma.language_model, self.gemma_expert.model]
|
|
|
|
query_states = []
|
|
key_states = []
|
|
value_states = []
|
|
gates = []
|
|
for i, hidden_states in enumerate(inputs_embeds):
|
|
layer = models[i].layers[layer_idx]
|
|
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
|
gates.append(gate)
|
|
|
|
input_shape = hidden_states.shape[:-1]
|
|
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
|
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
|
|
query_states.append(query_state)
|
|
key_states.append(key_state)
|
|
value_states.append(value_state)
|
|
|
|
# Concatenate and process attention
|
|
query_states = torch.cat(query_states, dim=2)
|
|
key_states = torch.cat(key_states, dim=2)
|
|
value_states = torch.cat(value_states, dim=2)
|
|
|
|
dummy_tensor = torch.zeros(
|
|
query_states.shape[0],
|
|
query_states.shape[2],
|
|
query_states.shape[-1],
|
|
device=query_states.device,
|
|
dtype=query_states.dtype,
|
|
)
|
|
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
|
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
|
query_states, key_states, cos, sin, unsqueeze_dim=1
|
|
)
|
|
|
|
batch_size = query_states.shape[0]
|
|
scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling
|
|
|
|
# Attention computation
|
|
att_output, _ = modeling_gemma.eager_attention_forward(
|
|
self.paligemma.language_model.layers[layer_idx].self_attn,
|
|
query_states,
|
|
key_states,
|
|
value_states,
|
|
attention_mask,
|
|
scaling,
|
|
)
|
|
# Get head_dim from the current layer, not from the model
|
|
head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
|
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
|
|
|
# Process layer outputs
|
|
outputs_embeds = []
|
|
start_pos = 0
|
|
for i, hidden_states in enumerate(inputs_embeds):
|
|
layer = models[i].layers[layer_idx]
|
|
end_pos = start_pos + hidden_states.shape[1]
|
|
|
|
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
|
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
|
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
|
|
|
# first residual
|
|
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
|
after_first_residual = out_emb.clone()
|
|
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
|
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
|
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
|
out_emb = out_emb.to(dtype=torch.bfloat16)
|
|
|
|
out_emb = layer.mlp(out_emb)
|
|
# second residual
|
|
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
|
outputs_embeds.append(out_emb)
|
|
start_pos = end_pos
|
|
|
|
return outputs_embeds
|
|
|
|
# Process all layers with gradient checkpointing if enabled
|
|
for layer_idx in range(num_layers):
|
|
if use_gradient_checkpointing:
|
|
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
|
compute_layer_complete,
|
|
layer_idx,
|
|
inputs_embeds,
|
|
attention_mask,
|
|
position_ids,
|
|
adarms_cond,
|
|
use_reentrant=False,
|
|
preserve_rng_state=False,
|
|
)
|
|
else:
|
|
inputs_embeds = compute_layer_complete(
|
|
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
|
|
)
|
|
|
|
# final norm
|
|
def compute_final_norms(inputs_embeds, adarms_cond):
|
|
outputs_embeds = []
|
|
for i, hidden_states in enumerate(inputs_embeds):
|
|
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
|
outputs_embeds.append(out_emb)
|
|
return outputs_embeds
|
|
|
|
# Apply gradient checkpointing to final norm if enabled
|
|
if use_gradient_checkpointing:
|
|
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
|
compute_final_norms,
|
|
inputs_embeds,
|
|
adarms_cond,
|
|
use_reentrant=False,
|
|
preserve_rng_state=False,
|
|
)
|
|
else:
|
|
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
|
|
|
prefix_output = outputs_embeds[0]
|
|
suffix_output = outputs_embeds[1]
|
|
prefix_past_key_values = None
|
|
|
|
return [prefix_output, suffix_output], prefix_past_key_values
|
|
|
|
|
|
class PI0Pytorch(nn.Module):  # see openpi `PI0Pytorch`
    """Core PI0 PyTorch model.

    Combines a ``PaliGemmaWithExpertModel`` with linear projections for the
    robot state and the (noisy) action chunk. ``forward`` returns the
    element-wise flow-matching loss used for training; ``sample_actions``
    integrates the predicted velocity from noise to an action chunk.

    State and action vectors are projected from / to a fixed width of 32
    inside this class.
    """

    def __init__(self, config: PI0OpenPIConfig):
        super().__init__()
        self.config = config

        vlm_variant = get_gemma_config(config.paligemma_variant)
        expert_variant = get_gemma_config(config.action_expert_variant)
        expert_width = expert_variant.width

        self.paligemma_with_expert = PaliGemmaWithExpertModel(
            vlm_variant,
            expert_variant,
            use_adarms=[False, False],
            precision=config.dtype,
        )

        # Projections between the 32-wide action/state space and the expert width.
        self.action_in_proj = nn.Linear(32, expert_width)
        self.action_out_proj = nn.Linear(expert_width, 32)

        self.state_proj = nn.Linear(32, expert_width)
        # MLP that fuses the action embedding with the time embedding.
        self.action_time_mlp_in = nn.Linear(2 * expert_width, expert_width)
        self.action_time_mlp_out = nn.Linear(expert_width, expert_width)

        # Gradient checkpointing starts disabled
        self.gradient_checkpointing_enabled = False

        # Optionally compile the sampling entry point
        if config.compile_model:
            torch.set_float32_matmul_precision("high")
            self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)

    def _set_gradient_checkpointing(self, enabled):
        """Write the checkpointing flag on this module and on each sub-model."""
        self.gradient_checkpointing_enabled = enabled
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = enabled
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = enabled
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = enabled

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory optimization."""
        self._set_gradient_checkpointing(True)
        logging.info("Enabled gradient checkpointing for PI0Pytorch model")

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self._set_gradient_checkpointing(False)
        logging.info("Disabled gradient checkpointing for PI0Pytorch model")

    def _apply_checkpoint(self, func, *args, **kwargs):
        """Call ``func``, through ``torch.utils.checkpoint`` when checkpointing is on and training."""
        if not (self.gradient_checkpointing_enabled and self.training):
            return func(*args, **kwargs)
        return torch.utils.checkpoint.checkpoint(
            func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
        )

    def _prepare_attention_masks_4d(self, att_2d_masks):
        """Turn a boolean [B, N, M] mask into an additive [B, 1, N, M] mask.

        Allowed positions become 0.0, blocked positions a very large negative
        number.
        """
        return torch.where(att_2d_masks.unsqueeze(1), 0.0, -2.3819763e38)

    def sample_noise(self, shape, device):
        """Draw standard-normal float32 noise of the given shape."""
        return torch.normal(mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device)

    def sample_time(self, bsize, device):
        """Sample ``bsize`` flow-matching times in [0.001, 1.0] from a scaled Beta distribution."""
        cfg = self.config
        beta_draw = sample_beta(cfg.time_sampling_beta_alpha, cfg.time_sampling_beta_beta, bsize, device)
        # Affine map so that the sampled time never reaches 0.
        shifted = beta_draw * 0.999 + 0.001
        return shifted.to(dtype=torch.float32, device=device)

    def embed_prefix(
        self, images, img_masks, lang_tokens, lang_masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with embedding layer.

        Returns:
            ``(embs, pad_masks, att_masks)``: embeddings concatenated along the
            token axis, the matching padding mask, and an all-false attention
            marker of the same [B, N] shape.
        """
        token_embs = []
        token_pad_masks = []
        num_prefix_tokens = 0

        # Images first, one camera at a time
        for img, img_mask in zip(images, img_masks, strict=True):
            img_emb = self._apply_checkpoint(self.paligemma_with_expert.embed_image, img)
            batch, num_img_embs = img_emb.shape[:2]

            token_embs.append(img_emb)
            # One per-sample validity flag, repeated for every image token.
            token_pad_masks.append(img_mask[:, None].expand(batch, num_img_embs))
            num_prefix_tokens += num_img_embs

        # Then the language tokens, scaled by sqrt(embedding dim)
        def lang_embed_func(lang_tokens):
            raw_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
            return raw_emb * math.sqrt(raw_emb.shape[-1])

        lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
        token_embs.append(lang_emb)
        token_pad_masks.append(lang_masks)
        num_prefix_tokens += lang_emb.shape[1]

        embs = torch.cat(token_embs, dim=1)
        pad_masks = torch.cat(token_pad_masks, dim=1)

        # Every prefix token carries marker 0, i.e. image and language tokens
        # all belong to one attention block.
        att_row = torch.zeros(num_prefix_tokens, dtype=torch.bool, device=pad_masks.device)
        att_masks = att_row[None, :].expand(pad_masks.shape[0], num_prefix_tokens)

        return embs, pad_masks, att_masks

    def embed_suffix(self, state, noisy_actions, timestep):
        """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.

        Returns:
            ``(embs, pad_masks, att_masks, adarms_cond)``; ``adarms_cond`` is
            always ``None`` here.
        """
        if self.state_proj.weight.dtype == torch.float32:
            state = state.to(torch.float32)

        state_emb = self._apply_checkpoint(self.state_proj, state)
        batch = state_emb.shape[0]
        state_mask = torch.ones(batch, 1, dtype=torch.bool, device=state_emb.device)

        # Embed timestep using sine-cosine positional encoding
        time_emb = create_sinusoidal_pos_embedding(
            timestep,
            self.action_in_proj.out_features,
            min_period=self.config.min_period,
            max_period=self.config.max_period,
            device=timestep.device,
        )
        time_emb = time_emb.type(dtype=timestep.dtype)

        # Fuse timestep + action information using an MLP
        action_emb = self._apply_checkpoint(self.action_in_proj, noisy_actions)
        time_emb = time_emb[:, None, :].expand_as(action_emb)
        fused_input = torch.cat([action_emb, time_emb], dim=2)

        def mlp_func(action_time_emb):
            hidden = F.silu(self.action_time_mlp_in(action_time_emb))
            return self.action_time_mlp_out(hidden)

        action_time_emb = self._apply_checkpoint(mlp_func, fused_input)
        adarms_cond = None

        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)

        embs = torch.cat([state_emb[:, None, :], action_time_emb], dim=1)
        pad_masks = torch.cat([state_mask, action_time_mask], dim=1)

        # Marker 1 for the state token and for the first action token (each
        # opens a new attention block); the remaining action tokens share the
        # block of the first one.
        block_markers = [1, 1] + [0] * (self.config.chunk_size - 1)
        att_row = torch.tensor(block_markers, dtype=embs.dtype, device=embs.device)
        att_masks = att_row[None, :].expand(bsize, len(att_row))

        return embs, pad_masks, att_masks, adarms_cond

    def forward(
        self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
    ) -> Tensor:
        """Do a full training forward pass and compute the loss.

        Noise and time are sampled when not supplied. Returns the unreduced
        MSE between the target velocity ``noise - actions`` and the prediction.
        """
        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)

        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)

        # Interpolate between clean actions (t=0) and noise (t=1).
        t = time[:, None, None]
        x_t = t * noise + (1 - t) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)

        # Match the embeddings to the backbone weights when those are bfloat16.
        vlm_first_layer = self.paligemma_with_expert.paligemma.language_model.layers[0]
        if vlm_first_layer.self_attn.q_proj.weight.dtype == torch.bfloat16:
            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
            prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        # Positions count valid tokens only.
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)

        def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
            (_, suffix_out), _ = self.paligemma_with_expert.forward(
                attention_mask=att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, suffix_embs],
                use_cache=False,
                adarms_cond=[None, adarms_cond],
            )
            return suffix_out

        suffix_out = self._apply_checkpoint(
            forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
        )

        # Keep only the action tokens (the last chunk_size positions).
        action_hidden = suffix_out[:, -self.config.chunk_size :].to(dtype=torch.float32)
        v_t = self._apply_checkpoint(self.action_out_proj, action_hidden)

        return F.mse_loss(u_t, v_t, reduction="none")

    @torch.no_grad()  # see openpi `sample_actions` (slightly adapted)
    def sample_actions(
        self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
    ) -> Tensor:
        """Do a full inference forward and compute the action.

        The prefix is encoded once and cached; the action chunk is then
        integrated from t=1 (noise) towards t=0 in ``num_steps`` Euler steps.
        """
        if num_steps is None:
            num_steps = self.config.num_inference_steps

        bsize = state.shape[0]
        device = state.device

        if noise is None:
            # Noise uses the padded width (32) expected by action_in_proj
            noise = self.sample_noise((bsize, self.config.chunk_size, 32), device)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001

        # Prefix-only pass whose key/value cache is reused by every denoising step.
        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks_4d,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=True,
        )

        # Negative step: time runs from 1 down to 0.
        dt = torch.tensor(-1.0 / num_steps, dtype=torch.float32, device=device)

        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        # Half-step tolerance so accumulated float error does not add or drop a step.
        while time >= -dt / 2:
            v_t = self.denoise_step(
                state,
                prefix_pad_masks,
                past_key_values,
                x_t,
                time.expand(bsize),
            )
            x_t = x_t + dt * v_t
            time += dt

        # Drop the padding dimensions before returning
        if self.config.action_dim < 32:
            x_t = x_t[:, :, : self.config.action_dim]

        return x_t

    def denoise_step(
        self,
        state,
        prefix_pad_masks,
        past_key_values,
        x_t,
        timestep,
    ):
        """Apply one denoising step of the noise `x_t` at a given timestep."""
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)

        batch_size, prefix_len = prefix_pad_masks.shape[:2]
        suffix_len = suffix_pad_masks.shape[1]

        # Every suffix token may look at every valid prefix token ...
        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
        # ... and at suffix tokens according to the block structure.
        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)

        # Suffix positions continue after the valid prefix tokens.
        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
        self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks_4d,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=False,
            adarms_cond=[None, adarms_cond],
        )

        action_hidden = outputs_embeds[1][:, -self.config.chunk_size :].to(dtype=torch.float32)
        return self.action_out_proj(action_hidden)
|
|
|
|
|
|
class PI0OpenPIPolicy(PreTrainedPolicy):
|
|
"""PI0 OpenPI Policy for LeRobot."""
|
|
|
|
config_class = PI0OpenPIConfig
|
|
name = "pi0_openpi"
|
|
|
|
def __init__(  # see lerobot pi0 `__init__`
    self,
    config: PI0OpenPIConfig,
    dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
    """
    Build the policy: normalizers, tokenizer and the core `PI0Pytorch` model.

    Args:
        config: Policy configuration class instance.
        dataset_stats: Dataset statistics to be used for normalization.
    """
    super().__init__(config)
    # Validate the feature configuration before anything is built from it.
    config.validate_features()
    self.config = config

    # Input / target normalizers and the matching output un-normalizer, all
    # built from the same normalization mapping and dataset statistics.
    self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
    self.normalize_targets = Normalize(
        config.output_features, config.normalization_mapping, dataset_stats
    )
    self.unnormalize_outputs = Unnormalize(
        config.output_features, config.normalization_mapping, dataset_stats
    )

    # Create tokenizer for language input
    # NOTE(review): the tokenizer id is hard-coded and is resolved by
    # `AutoTokenizer.from_pretrained` at construction time.
    self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")

    # Set max token length for tokenizer (from OpenPI)
    self.max_token_len = config.tokenizer_max_length

    # Initialize the core PI0 model
    self.model = PI0Pytorch(config)

    # Enable gradient checkpointing if requested
    if config.gradient_checkpointing:
        self.model.gradient_checkpointing_enable()

    # `reset` is defined outside this view -- presumably it initialises the
    # per-episode state; confirm.
    self.reset()
|
|
|
|
@classmethod
def from_pretrained(
    cls, *args, **kwargs
):  # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
    """Load a pretrained policy, then re-load its weights with remapped keys.

    The parent ``from_pretrained`` is called first with ``strict=False``.
    Afterwards ``model.safetensors`` is resolved again through
    ``transformers.utils.cached_file``, its keys are passed through
    ``_fix_pytorch_state_dict_keys``, every key that lacks a ``"model."``
    prefix receives one, and the result is loaded into the model using the
    caller's original ``strict`` value (``True`` when not given).

    The remapping is best effort: any exception raised while doing it
    (including one raised by the final ``load_state_dict``) is reported with
    a printed warning and the model produced by the parent loader is returned.

    Args:
        *args: Forwarded to the parent loader. The first positional argument,
            when present, is taken as the checkpoint name or path.
        **kwargs: Forwarded to the parent loader. ``strict`` is intercepted as
            described above. ``cache_dir``, ``force_download``,
            ``resume_download``, ``proxies``, ``use_auth_token``, ``revision``
            and ``local_files_only`` are additionally passed to
            ``cached_file``.

    Returns:
        The loaded policy instance.
    """
    print(
        "⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
        " This implementation follows the original OpenPI structure for compatibility. \n"
        " Original implementation: https://github.com/Physical-Intelligence/openpi"
    )

    def _report_keys(label, keys):
        """Print the number of ``keys`` in one category, listing at most five of them."""
        print(f"⚠️ {label} keys when loading state dict: {len(keys)} keys")
        for key in keys[:5]:
            print(f" - {key}")
        if len(keys) > 5:
            print(f" ... and {len(keys) - 5} more")

    # Remember the strictness the caller asked for; the parent load is always
    # lenient because the checkpoint keys are only fixed up afterwards.
    original_strict = kwargs.get("strict", True)
    kwargs["strict"] = False

    # Call parent from_pretrained with strict=False
    model = super().from_pretrained(*args, **kwargs)

    # Work out which checkpoint to re-read. The location may arrive
    # positionally or by keyword.
    # NOTE(review): both keyword spellings are accepted because the parent
    # loader's parameter name is not visible from here - confirm which one it uses.
    if args:
        pretrained_model_name_or_path = args[0]
    else:
        for name in ("pretrained_model_name_or_path", "pretrained_name_or_path"):
            if name in kwargs:
                pretrained_model_name_or_path = kwargs[name]
                break
        else:
            return model

    # Now manually load and remap the state dict
    try:
        from transformers.utils import cached_file

        print(f"Loading model from: {pretrained_model_name_or_path}")
        try:
            resolved_file = cached_file(
                pretrained_model_name_or_path,
                "model.safetensors",
                cache_dir=kwargs.get("cache_dir"),
                force_download=kwargs.get("force_download", False),
                resume_download=kwargs.get("resume_download"),
                proxies=kwargs.get("proxies"),
                use_auth_token=kwargs.get("use_auth_token"),
                revision=kwargs.get("revision"),
                local_files_only=kwargs.get("local_files_only", False),
            )
            from safetensors.torch import load_file

            original_state_dict = load_file(resolved_file)
            print("✓ Loaded state dict from model.safetensors")
        except Exception as e:
            # Without the raw checkpoint there is nothing to remap; keep the
            # weights the parent loader managed to load.
            print(f"Could not load state dict from remote files: {e}")
            return model

        # First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
        fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

        # Then add "model." prefix for all keys that don't already have it
        remapped_state_dict = {}
        remap_count = 0
        for key, value in fixed_state_dict.items():
            if key.startswith("model."):
                remapped_state_dict[key] = value
                continue
            new_key = f"model.{key}"
            remapped_state_dict[new_key] = value
            remap_count += 1
            if remap_count <= 10:  # Only print first 10 to avoid spam
                print(f"Remapped: {key} -> {new_key}")

        if remap_count > 10:
            print(f"... and {remap_count - 10} more keys remapped")

        print(f"Total keys remapped: {remap_count}")

        # Load the remapped state dict into the model
        missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict)

        if missing_keys:
            _report_keys("Missing", missing_keys)

        if unexpected_keys:
            _report_keys("Unexpected", unexpected_keys)

        if not missing_keys and not unexpected_keys:
            print("✅ All keys loaded successfully!")

    except Exception as e:
        # Deliberate best effort: fall back to whatever the parent loaded.
        print(f"⚠️ Warning: Could not remap state dict keys: {e}")
        print("Using default loading behavior")

    return model
|
|
|
|
def _fix_pytorch_state_dict_keys(
|
|
self, state_dict, model_config
|
|
): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
|
|
"""Fix state dict keys to match current model architecture."""
|
|
import re
|
|
|
|
fixed_state_dict = {}
|
|
|
|
for key, value in state_dict.items():
|
|
new_key = key
|
|
|
|
# Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
|
|
# For gemma expert layers
|
|
if re.match(
|
|
r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
|
|
key,
|
|
):
|
|
# This key structure suggests old model without adaRMS - keep as is or skip
|
|
logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}")
|
|
continue
|
|
|
|
if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
|
|
# Skip old norm structure
|
|
logging.warning(f"Skipping old norm key (no adaRMS support): {key}")
|
|
continue
|
|
|
|
# Handle MLP naming changes for pi0
|
|
# non-pi05 model expects action_time_mlp_*, but checkpoint might have time_mlp_*
|
|
if key.startswith("time_mlp_in."):
|
|
new_key = key.replace("time_mlp_in.", "action_time_mlp_in.")
|
|
elif key.startswith("time_mlp_out."):
|
|
new_key = key.replace("time_mlp_out.", "action_time_mlp_out.")
|
|
|
|
# Handle vision tower embedding layer potential differences
|
|
if "patch_embedding" in key:
|
|
# Some checkpoints might have this, but current model expects different structure
|
|
logging.warning(f"Vision embedding key might need handling: {key}")
|
|
|
|
fixed_state_dict[new_key] = value
|
|
|
|
return fixed_state_dict
|
|
|
|
def get_optim_params(self) -> dict:  # see lerobot pi0 `get_optim_params`
    """Return the parameters to hand to the optimizer.

    No filtering or grouping is applied: the value is exactly what
    ``self.parameters()`` returns.
    """
    all_params = self.parameters()
    return all_params
|
|
|
|
def reset(self):  # see lerobot pi0 `reset`
    """Drop any queued actions; meant to be called when the environment resets."""
    horizon = self.config.n_action_steps
    # Both containers are bounded by the configured number of action steps.
    self._action_queue = deque(maxlen=horizon)
    self._queues = {ACTION: deque(maxlen=horizon)}
|
|
|
|
def _preprocess_images(
|
|
self, batch: dict[str, Tensor]
|
|
) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images`
|
|
"""Preprocess images for the model.
|
|
|
|
Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
|
|
PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
|
|
"""
|
|
images = []
|
|
img_masks = []
|
|
|
|
# Get device from model parameters
|
|
device = next(self.parameters()).device
|
|
|
|
for key in self.config.image_keys:
|
|
if key in batch:
|
|
img = batch[key]
|
|
|
|
# Ensure tensor is on the same device as the model
|
|
if img.device != device:
|
|
img = img.to(device)
|
|
|
|
# Ensure float32 dtype for consistency
|
|
if img.dtype != torch.float32:
|
|
img = img.to(torch.float32)
|
|
|
|
# Check if image is in [B, C, H, W] format (channels first)
|
|
if img.dim() == 4 and img.shape[1] in [1, 3]: # Grayscale or RGB
|
|
# Already in correct format
|
|
pass
|
|
elif img.dim() == 4 and img.shape[-1] in [1, 3]: # [B, H, W, C] format
|
|
# Convert to [B, C, H, W]
|
|
img = img.permute(0, 3, 1, 2)
|
|
else:
|
|
raise ValueError(f"Unexpected image shape {img.shape} for key {key}")
|
|
|
|
# Resize with padding if needed
|
|
if img.shape[-2:] != self.config.image_resolution:
|
|
# resize_with_pad_torch handles both [B, C, H, W] and [B, H, W, C] formats
|
|
# But we need to ensure we pass it in the right format
|
|
img = resize_with_pad_torch(
|
|
img.permute(0, 2, 3, 1), # Convert to [B, H, W, C] for resize function
|
|
*self.config.image_resolution,
|
|
).permute(0, 3, 1, 2) # Convert back to [B, C, H, W]
|
|
|
|
# Normalize from [0, 1] to [-1, 1] for SigLIP/PaliGemma
|
|
# Check if normalization is needed
|
|
if img.min() >= 0 and img.max() <= 1:
|
|
img = img * 2.0 - 1.0
|
|
elif img.min() >= -1 and img.max() <= 1:
|
|
# Already normalized to [-1, 1]
|
|
pass
|
|
else:
|
|
# Assume it's in [0, 255] range and normalize
|
|
img = (img / 255.0) * 2.0 - 1.0
|
|
|
|
images.append(img)
|
|
# Create mask (all ones for real images)
|
|
img_masks.append(torch.ones(img.shape[0], dtype=torch.bool, device=device))
|
|
|
|
return images, img_masks
|
|
|
|
def _tokenize_language(
|
|
self, batch: dict[str, Tensor]
|
|
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
|
|
"""Tokenize language input using PaliGemma tokenizer."""
|
|
device = next(self.parameters()).device
|
|
|
|
# Get task description
|
|
if "task" in batch:
|
|
tasks = batch["task"]
|
|
if isinstance(tasks, str):
|
|
tasks = [tasks]
|
|
elif isinstance(tasks, list) and len(tasks) == 1:
|
|
# Expand to batch size
|
|
batch_size = batch[next(iter(batch.keys()))].shape[0]
|
|
tasks = tasks * batch_size
|
|
else:
|
|
# Default task if not provided
|
|
batch_size = batch[next(iter(batch.keys()))].shape[0]
|
|
tasks = ["Pick up the object"] * batch_size
|
|
|
|
# Tokenize with max_length padding to match OpenPI's expected format
|
|
tokenized = self.tokenizer(
|
|
tasks,
|
|
padding="max_length", # Use max_length padding as per OpenPI
|
|
padding_side="right", # from lerobot pi0 `prepare_language`
|
|
truncation=True,
|
|
max_length=self.max_token_len, # Use the max token length from config
|
|
return_tensors="pt",
|
|
)
|
|
|
|
lang_tokens = tokenized["input_ids"].to(device)
|
|
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
|
|
|
return lang_tokens, lang_masks
|
|
|
|
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:  # see lerobot pi0 `select_action`
    """Return the next action to execute for the given observations.

    Actions are served from an internal queue. When the queue is empty a new
    chunk is predicted, cut to ``config.n_action_steps`` along its second
    axis, and queued one step at a time.
    """
    self.eval()

    if not self._action_queue:
        n_steps = self.config.n_action_steps
        chunk = self.predict_action_chunk(batch)[:, :n_steps]
        # Swap the first two axes so that iterating yields one entry per step.
        for step_actions in chunk.transpose(0, 1):
            self._action_queue.append(step_actions)

    return self._action_queue.popleft()
|
|
|
|
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:  # see lerobot pi0 `select_action`
    """Predict a chunk of actions given environment observations.

    The batch is normalized, images and language are prepared, the state is
    zero-padded on its last dimension up to 32 entries, and actions are
    sampled from the model. The sampled actions are cut back to
    ``config.action_dim`` on their last dimension and un-normalized.

    Args:
        batch: Observation batch; must contain the state under ``OBS_STATE``.

    Returns:
        The un-normalized action chunk.

    Raises:
        ValueError: If the state's last dimension is larger than 32.
    """
    self.eval()

    batch = self.normalize_inputs(batch)

    # Prepare inputs
    images, img_masks = self._preprocess_images(batch)
    lang_tokens, lang_masks = self._tokenize_language(batch)
    state = batch[OBS_STATE]

    max_dim = 32  # PI0 expects fixed 32-dim state/action vectors

    # Validate state dimension
    state_dim = state.shape[-1]
    if state_dim > max_dim:
        raise ValueError(
            f"State dimension {state_dim} exceeds maximum of {max_dim}. "
            f"Please reduce state dimension or modify the model."
        )

    # Pad state to 32 dimensions if needed; works similar to lerobot pi0 `prepare_state`.
    # F.pad only touches the last dimension, so any number of leading
    # dimensions is supported.
    if state_dim < max_dim:
        state = F.pad(state, (0, max_dim - state_dim))

    # Sample actions using the model
    actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)

    # Truncate to actual action dimension, works similar to lerobot pi0 `prepare_action`
    if self.config.action_dim < max_dim:
        actions = actions[:, :, : self.config.action_dim]

    actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
    return actions
|
|
|
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:  # see lerobot pi0 `forward`
    """Run the batch through the model and compute the loss for training.

    Inputs and targets are normalized, images and language are prepared, and
    both state and actions are zero-padded on their last dimension up to 32
    entries before being passed to the model. The per-element losses returned
    by the model are cut back to ``config.action_dim`` on their last dimension
    and averaged.

    Args:
        batch: Training batch; must contain the state under ``OBS_STATE`` and
            the target actions under ``ACTION``.

    Returns:
        ``(loss, loss_dict)``: the scalar loss tensor, and a dict with the
        loss as a Python float (``"loss"``) and the loss averaged over the
        first two dimensions as a list (``"loss_per_dim"``).

    Raises:
        ValueError: If the last dimension of the state or of the actions is
            larger than 32.
    """
    batch = self.normalize_inputs(batch)
    batch = self.normalize_targets(batch)

    # Prepare inputs
    images, img_masks = self._preprocess_images(batch)
    lang_tokens, lang_masks = self._tokenize_language(batch)
    state = batch[OBS_STATE]
    actions = batch[ACTION]

    max_dim = 32  # PI0 expects fixed 32-dim state/action vectors

    # Validate state and action dimensions
    if state.shape[-1] > max_dim:
        raise ValueError(
            f"State dimension {state.shape[-1]} exceeds maximum of {max_dim}. "
            f"Please reduce state dimension or modify the model."
        )
    if actions.shape[-1] > max_dim:
        raise ValueError(
            f"Action dimension {actions.shape[-1]} exceeds maximum of {max_dim}. "
            f"Please reduce action dimension or modify the model."
        )

    # Pad state and actions to 32 dimensions if needed. F.pad only touches
    # the last dimension, so both tensors are handled the same way whatever
    # their number of leading dimensions.
    if state.shape[-1] < max_dim:
        state = F.pad(state, (0, max_dim - state.shape[-1]))
    if actions.shape[-1] < max_dim:
        actions = F.pad(actions, (0, max_dim - actions.shape[-1]))

    # Compute loss
    losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)

    # Truncate losses to actual action dimensions
    if self.config.action_dim < max_dim:
        losses = losses[:, :, : self.config.action_dim]

    loss = losses.mean()

    loss_dict = {
        "loss": loss.item(),
        # Tensor.tolist() yields plain Python floats directly; no NumPy
        # round-trip is needed.
        "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().tolist(),
    }

    return loss, loss_dict
|