From 419305a4c29d4757b4e284427ce1bafdeb05892d Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 23 Feb 2026 22:44:13 +0300 Subject: [PATCH] Fix: full pi models support for transformer v5 (#2967) * fix(pi): remove loss truncation * fix(pi): remove state padding before tokenization * fix(pi): fix image padding value * fix from_pretrain * add transformer v5 changes * remove reference * more fixes * make it work * add support for rest of pi family * add pifast work * more changes * more changes * more cleanup * fix torch params * dtype fix * torch compile * embed mismatch fix * revert groot * more nit fixes * remove unused classes * more fixes * revert * nit * torch dtype warning fix * but back dynamic renaming * add tie embedding --------- Co-authored-by: Yufei Sun --- docs/source/pi0fast.mdx | 20 +- src/lerobot/policies/pi0/modeling_pi0.py | 123 +++--- src/lerobot/policies/pi05/modeling_pi05.py | 125 +++--- src/lerobot/policies/pi05/processor_pi05.py | 4 - .../pi0_fast/configuration_pi0_fast.py | 2 +- .../policies/pi0_fast/modeling_pi0_fast.py | 53 ++- .../policies/pi0_fast/processor_pi0_fast.py | 4 - src/lerobot/policies/pi_gemma.py | 363 ++++++++++++++++++ src/lerobot/processor/tokenizer_processor.py | 2 +- .../scripts/lerobot_train_tokenizer.py | 2 +- .../test_pi0_fast_original_vs_lerobot.py | 10 +- tests/policies/pi0_pi05/test_pi0.py | 2 +- tests/policies/pi0_pi05/test_pi05.py | 2 +- 13 files changed, 517 insertions(+), 195 deletions(-) create mode 100644 src/lerobot/policies/pi_gemma.py diff --git a/docs/source/pi0fast.mdx b/docs/source/pi0fast.mdx index c4230fa79..85d975924 100644 --- a/docs/source/pi0fast.mdx +++ b/docs/source/pi0fast.mdx @@ -52,7 +52,7 @@ This approach can transform **any existing VLM** into a VLA by training it to pr You have two options for the FAST tokenizer: -1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer. +1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer. 2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data. 
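As a quick illustration of the two options above, here is a minimal sketch using `AutoProcessor`, mirroring the usage in `lerobot_train_tokenizer.py` further down in this diff. The dummy chunk shape and the `.fit()` call are assumptions about the FAST processor API, not code introduced by this patch:

```python
import numpy as np
from transformers import AutoProcessor

# Option 1: load the general-purpose pretrained FAST action tokenizer from the Hub.
tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True)

# Option 2 (sketch): fit a tokenizer to your own data. Per the training script below,
# .fit() expects a list of (chunk_size, action_dim) arrays of normalized actions and
# returns a tokenizer fitted to that dataset.
action_chunks = [np.random.uniform(-1.0, 1.0, size=(50, 7)) for _ in range(1000)]  # dummy data
custom_tokenizer = tokenizer.fit(action_chunks)
```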
@@ -114,15 +114,15 @@ lerobot-train \ ### Key Training Parameters -| Parameter | Description | Default | -| -------------------------------------- | -------------------------------------------------- | ---------------------------- | -| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` | -| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` | -| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` | -| `--policy.n_action_steps` | Number of action steps to execute | `50` | -| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` | -| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` | -| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` | +| Parameter | Description | Default | +| -------------------------------------- | -------------------------------------------------- | ------------------------------- | +| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` | +| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` | +| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` | +| `--policy.n_action_steps` | Number of action steps to execute | `50` | +| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` | +| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` | +| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` | ## Inference diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index a8ae83c95..2f77e9517 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -15,6 +15,7 @@ # limitations under the License. 
import builtins +import copy import logging import math from collections import deque @@ -32,13 +33,21 @@ from lerobot.utils.import_utils import _transformers_available if TYPE_CHECKING or _transformers_available: from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma - from transformers.models.gemma.modeling_gemma import GemmaForCausalLM - from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + + from lerobot.policies.pi_gemma import ( + PaliGemmaForConditionalGenerationWithPiGemma, + PiGemmaForCausalLM, + _gated_residual, + layernorm_forward, + ) else: CONFIG_MAPPING = None modeling_gemma = None - GemmaForCausalLM = None - PaliGemmaForConditionalGeneration = None + PiGemmaForCausalLM = None + _gated_residual = None + layernorm_forward = None + PaliGemmaForConditionalGenerationWithPiGemma = None + from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config @@ -59,11 +68,6 @@ class ActionSelectKwargs(TypedDict, total=False): execution_horizon: int | None -def _gated_residual(residual: torch.Tensor, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - """Gated residual connection: residual + gate * hidden_states.""" - return residual + gate.unsqueeze(-1) * hidden_states - - def get_safe_dtype(target_dtype, device_type): """Get a safe dtype for the given device type.""" if device_type == "mps" and target_dtype == torch.float64: @@ -196,7 +200,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) if images.dtype == torch.uint8: resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) elif images.dtype == torch.float32: - resized_images = resized_images.clamp(-1.0, 1.0) + resized_images = resized_images.clamp(0.0, 1.0) else: raise ValueError(f"Unsupported image dtype: {images.dtype}") @@ -207,7 +211,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) pad_w1 = pad_w0 + remainder_w # Pad - constant_value = 0 if images.dtype == torch.uint8 else -1.0 + constant_value = 0 if images.dtype == torch.uint8 else 0.0 padded_images = F.pad( resized_images, (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom @@ -222,35 +226,6 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) return padded_images -class AdaRMSNorm(nn.Module): - """RMSNorm wrapper that supports optional AdaRMS conditioning. - - When called with `cond=None`, behaves like standard RMSNorm and returns a gate of ones. - When called with a conditioning tensor, applies AdaRMS: uses a linear projection to produce - a scale and gate from the conditioning input. 
- """ - - def __init__(self, base_norm: nn.Module, cond_dim: int | None = None): - super().__init__() - self.base_norm = base_norm - if cond_dim is not None: - hidden_size = base_norm.weight.shape[0] - self.ada_proj = nn.Linear(cond_dim, 2 * hidden_size, bias=False) - nn.init.zeros_(self.ada_proj.weight) - else: - self.ada_proj = None - - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None): - normed = self.base_norm(x) - if cond is None or self.ada_proj is None: - gate = torch.ones(x.shape[:-1], dtype=x.dtype, device=x.device) - return normed, gate - scale_gate = self.ada_proj(cond) - scale, gate = scale_gate.chunk(2, dim=-1) - normed = normed * (1 + scale) - return normed, gate - - # Define the complete layer computation function for gradient checkpointing def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert @@ -262,13 +237,7 @@ def compute_layer_complete( gates = [] for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] - if isinstance(layer.input_layernorm, AdaRMSNorm): - hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 - else: - hidden_states = layer.input_layernorm(hidden_states) # noqa: PLW2901 - gate = torch.ones( - hidden_states.shape[:-1], dtype=hidden_states.dtype, device=hidden_states.device - ) + hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i]) gates.append(gate) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) @@ -317,19 +286,15 @@ def compute_layer_complete( att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual - out_emb = _gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + out_emb = _gated_residual(hidden_states, out_emb, gates[i]) after_first_residual = out_emb.clone() - if isinstance(layer.post_attention_layernorm, AdaRMSNorm): - out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) - else: - out_emb = layer.post_attention_layernorm(out_emb) - gate = torch.ones(out_emb.shape[:-1], dtype=out_emb.dtype, device=out_emb.device) + out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i]) # Convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual - out_emb = _gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + out_emb = _gated_residual(after_first_residual, out_emb, gate) outputs_embeds.append(out_emb) start_pos = end_pos return outputs_embeds @@ -402,7 +367,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" - vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.dtype = "float32" vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None @@ -410,7 +375,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - 
vlm_config_hf.vision_config.torch_dtype = "float32" + vlm_config_hf.vision_config.dtype = "float32" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, @@ -421,13 +386,13 @@ class PaliGemmaWithExpertModel( num_key_value_heads=action_expert_config.num_kv_heads, vocab_size=257152, hidden_activation="gelu_pytorch_tanh", - torch_dtype="float32", + dtype="float32", use_adarms=use_adarms[1], adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, ) - self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) - self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) + self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf) + self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf) self.gemma_expert.model.embed_tokens = None self.to_bfloat16_for_selected_params(precision) @@ -442,10 +407,11 @@ class PaliGemmaWithExpertModel( else: raise ValueError(f"Invalid precision: {precision}") + # Keep full vision path in float32 so we never toggle (toggle causes optimizer + # "same dtype" error). Align with PI05. params_to_keep_float32 = [ - "vision_tower.vision_model.embeddings.patch_embedding.weight", - "vision_tower.vision_model.embeddings.patch_embedding.bias", - "vision_tower.vision_model.embeddings.position_embedding.weight", + "vision_tower", + "multi_modal_projector", "input_layernorm", "post_attention_layernorm", "model.norm", @@ -473,7 +439,15 @@ class PaliGemmaWithExpertModel( self.paligemma.eval() def embed_image(self, image: torch.Tensor): - return self.paligemma.model.get_image_features(image) + # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05. + out_dtype = image.dtype + if image.dtype != torch.float32: + image = image.to(torch.float32) + image_outputs = self.paligemma.model.get_image_features(image) + features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 + if features.dtype != out_dtype: + features = features.to(out_dtype) + return features def embed_language_tokens(self, tokens: torch.Tensor): return self.paligemma.model.language_model.embed_tokens(tokens) @@ -554,11 +528,7 @@ class PaliGemmaWithExpertModel( def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - norm = models[i].norm - if isinstance(norm, AdaRMSNorm): - out_emb, _ = norm(hidden_states, cond=adarms_cond[i]) - else: - out_emb = norm(hidden_states) + out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds @@ -946,6 +916,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + past_key_values = copy.deepcopy(past_key_values) outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, @@ -1035,14 +1006,12 @@ class PI0Policy(PreTrainedPolicy): # Check if dataset_stats were provided in kwargs model = cls(config, **kwargs) - # Now manually load and remap the state dict + # Load state dict (expects keys with "model." 
prefix) try: - # Try to load the pytorch_model.bin or model.safetensors file print(f"Loading model from: {pretrained_name_or_path}") try: from transformers.utils import cached_file - # Try safetensors first resolved_file = cached_file( pretrained_name_or_path, "model.safetensors", @@ -1063,7 +1032,7 @@ class PI0Policy(PreTrainedPolicy): print("Returning model without loading pretrained weights") return model - # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys) fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it @@ -1108,7 +1077,7 @@ class PI0Policy(PreTrainedPolicy): print("All keys loaded successfully!") except Exception as e: - print(f"Warning: Could not remap state dict keys: {e}") + print(f"Warning: Could not load state dict: {e}") return model @@ -1158,6 +1127,14 @@ class PI0Policy(PreTrainedPolicy): # Some checkpoints might have this, but current model expects different structure logging.warning(f"Vision embedding key might need handling: {key}") + if ( + key == "model.paligemma_with_expert.paligemma.lm_head.weight" + or key == "paligemma_with_expert.paligemma.lm_head.weight" + ): + fixed_state_dict[ + "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ] = value.clone() + fixed_state_dict[new_key] = value return fixed_state_dict diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index d5be17079..4a74250a0 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -15,6 +15,7 @@ # limitations under the License. 
import builtins +import copy import logging import math from collections import deque @@ -32,14 +33,20 @@ from lerobot.utils.import_utils import _transformers_available if TYPE_CHECKING or _transformers_available: from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma - from transformers.models.gemma.modeling_gemma import GemmaForCausalLM - from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + + from lerobot.policies.pi_gemma import ( + PaliGemmaForConditionalGenerationWithPiGemma, + PiGemmaForCausalLM, + _gated_residual, + layernorm_forward, + ) else: CONFIG_MAPPING = None modeling_gemma = None - GemmaForCausalLM = None - PaliGemmaForConditionalGeneration = None - + PiGemmaForCausalLM = None + _gated_residual = None + layernorm_forward = None + PaliGemmaForConditionalGenerationWithPiGemma = None from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config from lerobot.policies.pretrained import PreTrainedPolicy, T @@ -58,11 +65,6 @@ class ActionSelectKwargs(TypedDict, total=False): execution_horizon: int | None -def _gated_residual(residual: torch.Tensor, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - """Gated residual connection: residual + gate * hidden_states.""" - return residual + gate.unsqueeze(-1) * hidden_states - - def get_safe_dtype(target_dtype, device_type): """Get a safe dtype for the given device type.""" if device_type == "mps" and target_dtype == torch.float64: @@ -194,7 +196,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) if images.dtype == torch.uint8: resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) elif images.dtype == torch.float32: - resized_images = resized_images.clamp(-1.0, 1.0) + resized_images = resized_images.clamp(0.0, 1.0) else: raise ValueError(f"Unsupported image dtype: {images.dtype}") @@ -205,7 +207,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) pad_w1 = pad_w0 + remainder_w # Pad - constant_value = 0 if images.dtype == torch.uint8 else -1.0 + constant_value = 0 if images.dtype == torch.uint8 else 0.0 padded_images = F.pad( resized_images, (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom @@ -220,35 +222,6 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) return padded_images -class AdaRMSNorm(nn.Module): - """RMSNorm wrapper that supports optional AdaRMS conditioning. - - When called with `cond=None`, behaves like standard RMSNorm and returns a gate of ones. - When called with a conditioning tensor, applies AdaRMS: uses a linear projection to produce - a scale and gate from the conditioning input. 
- """ - - def __init__(self, base_norm: nn.Module, cond_dim: int | None = None): - super().__init__() - self.base_norm = base_norm - if cond_dim is not None: - hidden_size = base_norm.weight.shape[0] - self.ada_proj = nn.Linear(cond_dim, 2 * hidden_size, bias=False) - nn.init.zeros_(self.ada_proj.weight) - else: - self.ada_proj = None - - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None): - normed = self.base_norm(x) - if cond is None or self.ada_proj is None: - gate = torch.ones(x.shape[:-1], dtype=x.dtype, device=x.device) - return normed, gate - scale_gate = self.ada_proj(cond) - scale, gate = scale_gate.chunk(2, dim=-1) - normed = normed * (1 + scale) - return normed, gate - - # Define the complete layer computation function for gradient checkpointing def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert @@ -260,13 +233,7 @@ def compute_layer_complete( gates = [] for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] - if isinstance(layer.input_layernorm, AdaRMSNorm): - hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 - else: - hidden_states = layer.input_layernorm(hidden_states) # noqa: PLW2901 - gate = torch.ones( - hidden_states.shape[:-1], dtype=hidden_states.dtype, device=hidden_states.device - ) + hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i]) gates.append(gate) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) @@ -315,19 +282,15 @@ def compute_layer_complete( att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual - out_emb = _gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + out_emb = _gated_residual(hidden_states, out_emb, gates[i]) after_first_residual = out_emb.clone() - if isinstance(layer.post_attention_layernorm, AdaRMSNorm): - out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) - else: - out_emb = layer.post_attention_layernorm(out_emb) - gate = torch.ones(out_emb.shape[:-1], dtype=out_emb.dtype, device=out_emb.device) + out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i]) # Convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual - out_emb = _gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + out_emb = _gated_residual(after_first_residual, out_emb, gate) outputs_embeds.append(out_emb) start_pos = end_pos return outputs_embeds @@ -400,7 +363,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" - vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.dtype = "float32" vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None @@ -408,7 +371,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - 
vlm_config_hf.vision_config.torch_dtype = "float32" + vlm_config_hf.vision_config.dtype = "float32" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, @@ -419,13 +382,13 @@ class PaliGemmaWithExpertModel( num_key_value_heads=action_expert_config.num_kv_heads, vocab_size=257152, hidden_activation="gelu_pytorch_tanh", - torch_dtype="float32", + dtype="float32", use_adarms=use_adarms[1], adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, ) - self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) - self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) + self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf) + self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf) self.gemma_expert.model.embed_tokens = None self.to_bfloat16_for_selected_params(precision) @@ -440,10 +403,11 @@ class PaliGemmaWithExpertModel( else: raise ValueError(f"Invalid precision: {precision}") + # Keep full vision path in float32 so we never toggle (toggle causes optimizer + # "same dtype" error). Saves memory vs full float32; more memory than only 3 params. params_to_keep_float32 = [ - "vision_tower.vision_model.embeddings.patch_embedding.weight", - "vision_tower.vision_model.embeddings.patch_embedding.bias", - "vision_tower.vision_model.embeddings.position_embedding.weight", + "vision_tower", + "multi_modal_projector", "input_layernorm", "post_attention_layernorm", "model.norm", @@ -471,7 +435,15 @@ class PaliGemmaWithExpertModel( self.paligemma.eval() def embed_image(self, image: torch.Tensor): - return self.paligemma.model.get_image_features(image) + # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). + out_dtype = image.dtype + if image.dtype != torch.float32: + image = image.to(torch.float32) + image_outputs = self.paligemma.model.get_image_features(image) + features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 + if features.dtype != out_dtype: + features = features.to(out_dtype) + return features def embed_language_tokens(self, tokens: torch.Tensor): return self.paligemma.model.language_model.embed_tokens(tokens) @@ -552,11 +524,7 @@ class PaliGemmaWithExpertModel( def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - norm = models[i].norm - if isinstance(norm, AdaRMSNorm): - out_emb, _ = norm(hidden_states, cond=adarms_cond[i]) - else: - out_emb = norm(hidden_states) + out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds @@ -918,6 +886,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + past_key_values = copy.deepcopy(past_key_values) outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, @@ -1007,14 +976,12 @@ class PI05Policy(PreTrainedPolicy): # Check if dataset_stats were provided in kwargs model = cls(config, **kwargs) - # Now manually load and remap the state dict + # Load state dict (expects keys with "model." 
prefix) try: - # Try to load the pytorch_model.bin or model.safetensors file print(f"Loading model from: {pretrained_name_or_path}") try: from transformers.utils import cached_file - # Try safetensors first resolved_file = cached_file( pretrained_name_or_path, "model.safetensors", @@ -1035,7 +1002,7 @@ class PI05Policy(PreTrainedPolicy): print("Returning model without loading pretrained weights") return model - # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys) fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it @@ -1047,8 +1014,6 @@ class PI05Policy(PreTrainedPolicy): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 - if remap_count <= 10: # Only print first 10 to avoid spam - print(f"Remapped: {key} -> {new_key}") else: remapped_state_dict[key] = value @@ -1082,7 +1047,7 @@ class PI05Policy(PreTrainedPolicy): print("All keys loaded successfully!") except Exception as e: - print(f"Warning: Could not remap state dict keys: {e}") + print(f"Warning: Could not load state dict: {e}") return model @@ -1136,6 +1101,14 @@ class PI05Policy(PreTrainedPolicy): # Some checkpoints might have this, but current model expects different structure logging.warning(f"Vision embedding key might need handling: {key}") + if ( + key == "model.paligemma_with_expert.paligemma.lm_head.weight" + or key == "paligemma_with_expert.paligemma.lm_head.weight" + ): + fixed_state_dict[ + "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ] = value.clone() + fixed_state_dict[new_key] = value return fixed_state_dict diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index e29bc4c23..6e01a4e16 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -23,7 +23,6 @@ import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.policies.pi05.configuration_pi05 import PI05Config -from lerobot.policies.pi05.modeling_pi05 import pad_vector from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -68,9 +67,6 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): # TODO: check if this necessary state = deepcopy(state) - # Prepare state (pad to max_state_dim) - state = pad_vector(state, self.max_state_dim) - # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) state_np = state.cpu().numpy() diff --git a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py index 96137e91f..e12522833 100644 --- a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py @@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig): tokenizer_max_length: int = 200 # see openpi `__post_init__` text_tokenizer_name: str = "google/paligemma-3b-pt-224" - action_tokenizer_name: str = "physical-intelligence/fast" + action_tokenizer_name: str = "lerobot/fast-action-tokenizer" temperature: float = 0.0 max_decoding_steps: int = 256 fast_skip_tokens: int = 128 diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py 
b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index 47e1df8db..52fc2504d 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -38,11 +38,16 @@ else: if TYPE_CHECKING or _transformers_available: from transformers import AutoTokenizer from transformers.models.auto import CONFIG_MAPPING - from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + + from lerobot.policies.pi_gemma import ( + PaliGemmaForConditionalGenerationWithPiGemma, + PiGemmaModel, + ) else: CONFIG_MAPPING = None - PaliGemmaForConditionalGeneration = None AutoTokenizer = None + PiGemmaModel = None + PaliGemmaForConditionalGenerationWithPiGemma = None from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig @@ -121,7 +126,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) if images.dtype == torch.uint8: resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) elif images.dtype == torch.float32: - resized_images = resized_images.clamp(-1.0, 1.0) + resized_images = resized_images.clamp(0.0, 1.0) else: raise ValueError(f"Unsupported image dtype: {images.dtype}") @@ -132,7 +137,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) pad_w1 = pad_w0 + remainder_w # Pad - constant_value = 0 if images.dtype == torch.uint8 else -1.0 + constant_value = 0 if images.dtype == torch.uint8 else 0.0 padded_images = F.pad( resized_images, (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom @@ -206,16 +211,22 @@ class PI0FastPaliGemma(nn.Module): vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" - vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.dtype = "float32" vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - vlm_config_hf.vision_config.torch_dtype = "float32" + vlm_config_hf.vision_config.dtype = "float32" - self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) + self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf) + + # Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that + # forward(..., adarms_cond=...) is supported (same as pi0/pi05). + if use_adarms[0]: + text_config = self.paligemma.config.text_config + self.paligemma.model.language_model = PiGemmaModel(text_config) self.to_bfloat16_for_selected_params(precision) @@ -228,10 +239,11 @@ class PI0FastPaliGemma(nn.Module): else: raise ValueError(f"Invalid precision: {precision}") + # Keep full vision path in float32 so we never toggle (toggle causes optimizer + # "same dtype" error). Align with PI05. 
params_to_keep_float32 = [ - "vision_tower.vision_model.embeddings.patch_embedding.weight", - "vision_tower.vision_model.embeddings.patch_embedding.bias", - "vision_tower.vision_model.embeddings.position_embedding.weight", + "vision_tower", + "multi_modal_projector", "input_layernorm", "post_attention_layernorm", "model.norm", @@ -242,7 +254,15 @@ class PI0FastPaliGemma(nn.Module): param.data = param.data.to(dtype=torch.float32) def embed_image(self, image: torch.Tensor): - return self.paligemma.model.get_image_features(image) + # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05. + out_dtype = image.dtype + if image.dtype != torch.float32: + image = image.to(torch.float32) + image_outputs = self.paligemma.model.get_image_features(image) + features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 + if features.dtype != out_dtype: + features = features.to(out_dtype) + return features def embed_language_tokens(self, tokens: torch.Tensor): return self.paligemma.model.language_model.embed_tokens(tokens) @@ -887,14 +907,12 @@ class PI0FastPolicy(PreTrainedPolicy): # Check if dataset_stats were provided in kwargs model = cls(config, **kwargs) - # Now manually load and remap the state dict + # Load state dict (expects keys with "model." prefix) try: - # Try to load the pytorch_model.bin or model.safetensors file print(f"Loading model from: {pretrained_name_or_path}") try: from transformers.utils import cached_file - # Try safetensors first resolved_file = cached_file( pretrained_name_or_path, "model.safetensors", @@ -915,8 +933,9 @@ class PI0FastPolicy(PreTrainedPolicy): print("Returning model without loading pretrained weights") return model - # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys) fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + # Then add "model." 
prefix for all keys that don't already have it remapped_state_dict = {} remap_count = 0 @@ -926,8 +945,6 @@ class PI0FastPolicy(PreTrainedPolicy): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 - if remap_count <= 10: # Only print first 10 to avoid spam - print(f"Remapped: {key} -> {new_key}") else: remapped_state_dict[key] = value @@ -961,7 +978,7 @@ class PI0FastPolicy(PreTrainedPolicy): print("All keys loaded successfully!") except Exception as e: - print(f"Warning: Could not remap state dict keys: {e}") + print(f"Warning: Could not load state dict: {e}") return model diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py index 0d9dac673..fde7d5c80 100644 --- a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py @@ -23,7 +23,6 @@ import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig -from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector from lerobot.processor import ( ActionTokenizerProcessorStep, AddBatchDimensionProcessorStep, @@ -69,9 +68,6 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep): # TODO: check if this necessary state = deepcopy(state) - # Prepare state (pad to max_state_dim) - state = pad_vector(state, self.max_state_dim) - # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) state_np = state.cpu().numpy() diff --git a/src/lerobot/policies/pi_gemma.py b/src/lerobot/policies/pi_gemma.py new file mode 100644 index 000000000..35a6ae0d2 --- /dev/null +++ b/src/lerobot/policies/pi_gemma.py @@ -0,0 +1,363 @@ +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
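For context on the state handling in the processor changes above (padding removed, discretization kept), here is a hypothetical sketch of the 256-bin discretization the comment refers to; the actual binning lives in openpi's `PaligemmaTokenizer.tokenize()` and may differ in detail:

```python
import numpy as np

# Dummy state vector, already normalized to [-1, 1] by the NormalizerProcessorStep.
state = np.random.uniform(-1.0, 1.0, size=(14,))

# Discretize into 256 bins over [-1, 1]; bin ids end up in [0, 255].
bins = np.linspace(-1.0, 1.0, 256 + 1)[:-1]
discretized = np.digitize(np.clip(state, -1.0, 1.0), bins) - 1
```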
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch import nn + +from lerobot.utils.import_utils import _transformers_available + +if TYPE_CHECKING or _transformers_available: + from transformers.cache_utils import DynamicCache + from transformers.masking_utils import create_causal_mask + from transformers.modeling_layers import GradientCheckpointingLayer + from transformers.modeling_outputs import BaseModelOutputWithPast + from transformers.models.gemma.modeling_gemma import ( + GemmaAttention, + GemmaConfig, + GemmaForCausalLM, + GemmaMLP, + GemmaModel, + ) + from transformers.models.paligemma.modeling_paligemma import ( + PaliGemmaForConditionalGeneration, + PaliGemmaModel, + ) +else: + GemmaAttention = None + GemmaConfig = None + GemmaForCausalLM = None + GemmaMLP = None + GemmaModel = None + PaliGemmaModel = None + PaliGemmaForConditionalGeneration = None + DynamicCache = None + GradientCheckpointingLayer = None + BaseModelOutputWithPast = None + create_causal_mask = None + + +def _gated_residual( + x: torch.Tensor | None, + y: torch.Tensor | None, + gate: torch.Tensor | None, +) -> torch.Tensor | None: + """Gated residual: x + y when gate is None, else x + y * gate.""" + if x is None and y is None: + return None + if x is None or y is None: + return x if x is not None else y + if gate is None: + return x + y + return x + y * gate + + +def layernorm_forward( + layernorm: nn.Module, + x: torch.Tensor, + cond: torch.Tensor | None = None, +): + """ + call layernorm and return hidden states and gate + if cond is not None, use conditional norm + otherwise, use normal gemma norm + """ + if cond is not None: + return layernorm(x, cond=cond) + else: + return layernorm(x) + + +class PiGemmaRMSNorm(nn.Module): + """ + Adaptive RMSNorm for PI Gemma (AdaRMS). + When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm. + forward(x, cond=None) returns (output, gate) for use with _gated_residual. 
+ """ + + def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): + super().__init__() + self.eps = eps + self.dim = dim + self.cond_dim = cond_dim + if cond_dim is not None: + self.dense = nn.Linear(cond_dim, dim * 3, bias=True) + nn.init.zeros_(self.dense.weight) + else: + self.weight = nn.Parameter(torch.zeros(dim)) + self.dense = None + + def _norm(self, x): + # Compute variance in float32 (like the source implementation) + var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True) + # Compute normalization in float32 + normed_inputs = x * torch.rsqrt(var + self.eps) + return normed_inputs + + def forward( + self, + x: torch.Tensor, + cond: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + dtype = x.dtype + normed = self._norm(x) + if cond is None or self.dense is None: + normed = normed * (1.0 + self.weight.float()) + return normed.type_as(x), None + if cond.shape[-1] != self.cond_dim: + raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}") + modulation = self.dense(cond) + if len(x.shape) == 3: + modulation = modulation.unsqueeze(1) + scale, shift, gate = modulation.chunk(3, dim=-1) + normed = normed * (1 + scale.float()) + shift.float() + return normed.to(dtype), gate.to(dtype) + + def extra_repr(self) -> str: + if self.dense is not None: + return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}" + return f"dim={self.dim}, eps={self.eps}" + + +def _get_pi_gemma_decoder_layer_base(): + """base for PiGemmaDecoderLayer""" + + class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer): + """Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma.""" + + def __init__(self, config: GemmaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) + self.mlp = GemmaMLP(config) + cond_dim = ( + getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None + ) + self.input_layernorm = PiGemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim + ) + self.post_attention_layernorm = PiGemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values=None, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + adarms_cond: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond) + hidden_states, _ = self.self_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = _gated_residual(residual, hidden_states, gate) + + residual = hidden_states + hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond) + hidden_states = self.mlp(hidden_states) + hidden_states = _gated_residual(residual, hidden_states, gate) + return hidden_states + + return _PiGemmaDecoderLayerBase + + +class PiGemmaModel(GemmaModel): # type: ignore[misc] + """ + GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms 
is True. + """ + + def __init__(self, config: GemmaConfig, **kwargs): + super().__init__(config, **kwargs) + # if not getattr(config, "use_adarms", False): + # return + cond_dim = getattr(config, "adarms_cond_dim", None) + pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base() + self.layers = nn.ModuleList( + [pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: DynamicCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + adarms_cond: torch.Tensor | None = None, + **kwargs, + ) -> BaseModelOutputWithPast: + """ + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. + """ + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + import logging + + logging.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + # embed positions + hidden_states = inputs_embeds + # Convert to bfloat16 if the first layer uses bfloat16 + if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # normalized + # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + adarms_cond=adarms_cond, + **kwargs, + ) + + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states, _ = self.norm(hidden_states, adarms_cond) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc] + """ + Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM + and the language model used in pi0_fast. Use this for the action expert in pi0/pi05. 
+ """ + + def __init__(self, config: GemmaConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = PiGemmaModel(config) + + +class PaliGemmaModelWithPiGemma(PaliGemmaModel): + """PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals).""" + + def __init__(self, config): + super().__init__(config) + self.language_model = PiGemmaModel(config.text_config) + + +class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration): + """PaliGemmaForConditionalGeneration using PiGemma decoder for the language model.""" + + def __init__(self, config): + super().__init__(config) + self.model = PaliGemmaModelWithPiGemma(config) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + +__all__ = [ + "PiGemmaModel", + "PiGemmaForCausalLM", + "PiGemmaRMSNorm", + "_gated_residual", + "layernorm_forward", + "PaliGemmaModelWithPiGemma", + "PaliGemmaForConditionalGenerationWithPiGemma", +] diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index df559555a..da6e600af 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): Requires the `transformers` library to be installed. Attributes: - tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast"). + tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer"). tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored. trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers). action_tokenizer: The internal tokenizer/processor instance, loaded during initialization. 
diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py index 1d8f4644b..807d48333 100644 --- a/src/lerobot/scripts/lerobot_train_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -306,7 +306,7 @@ def train_fast_tokenizer( # download the tokenizer source code (not pretrained weights) # we'll train a new tokenizer on our own data - base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True) + base_tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True) # convert action_chunks array to list of arrays (expected by .fit()) action_data_list = [action_chunks[i] for i in range(len(action_chunks))] diff --git a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py index 9ebc4ba89..7b1bbce7d 100644 --- a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py +++ b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py @@ -54,19 +54,19 @@ IMAGE_HEIGHT = 224 IMAGE_WIDTH = 224 NUM_VIEWS = 2 # Number of camera views DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -MODEL_PATH_LEROBOT = "lerobot/pi0fast-base" +MODEL_PATH_LEROBOT = "jadechoghari/pi0fast-base" # Expected action token shape: (batch_size, max_decoding_steps) EXPECTED_ACTION_TOKENS_SHAPE = (1, 2) # Expected first 5 action tokens (for reproducibility check) -EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255362]) +EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255425]) # Expected actions after detokenization EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim) -EXPECTED_ACTIONS_MEAN = 0.04419417306780815 -EXPECTED_ACTIONS_STD = 0.26231569051742554 -EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 1.4849, 0.0000, 0.0000, 0.0000]) +EXPECTED_ACTIONS_MEAN = 0.046403881162405014 +EXPECTED_ACTIONS_STD = 0.2607129216194153 +EXPECTED_ACTIONS_FIRST_5 = torch.tensor([-0.0707, 1.4849, 0.0000, 0.0000, 0.0000]) def set_seed_all(seed: int): diff --git a/tests/policies/pi0_pi05/test_pi0.py b/tests/policies/pi0_pi05/test_pi0.py index b580310eb..230e43201 100644 --- a/tests/policies/pi0_pi05/test_pi0.py +++ b/tests/policies/pi0_pi05/test_pi0.py @@ -24,7 +24,7 @@ import torch # Skip this entire module in CI pytestmark = pytest.mark.skipif( os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires local OpenPI installation and is not meant for CI", + reason="This test requires accepting the model license", ) from lerobot.policies.factory import make_policy_config # noqa: E402 diff --git a/tests/policies/pi0_pi05/test_pi05.py b/tests/policies/pi0_pi05/test_pi05.py index 964539446..acb616960 100644 --- a/tests/policies/pi0_pi05/test_pi05.py +++ b/tests/policies/pi0_pi05/test_pi05.py @@ -26,7 +26,7 @@ from lerobot.utils.random_utils import set_seed # Skip this entire module in CI pytestmark = pytest.mark.skipif( os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires local OpenPI installation and is not meant for CI", + reason="This test requires accepting the model license", ) from lerobot.policies.factory import make_policy_config # noqa: E402
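To recap the core mechanism this patch introduces, here is a condensed, self-contained sketch of the AdaRMS conditioning from `PiGemmaRMSNorm` plus the gated residual from `_gated_residual` in the new `pi_gemma.py`. It is simplified (bias handling, dtype casts, and the unconditioned path are trimmed) and is not a drop-in replacement:

```python
import torch
from torch import nn


class AdaRMSSketch(nn.Module):
    """Condensed sketch of PiGemmaRMSNorm: cond -> (scale, shift, gate) modulation."""

    def __init__(self, dim: int, cond_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
        nn.init.zeros_(self.dense.weight)  # weights zero-initialized, as in the patch

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        # RMS-normalize in float32, then modulate with scale/shift derived from cond.
        normed = x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        scale, shift, gate = self.dense(cond).unsqueeze(1).chunk(3, dim=-1)
        return (normed * (1 + scale) + shift).to(x.dtype), gate.to(x.dtype)


x = torch.randn(2, 10, 64)   # (batch, seq, hidden) activations
cond = torch.randn(2, 32)    # conditioning vector (e.g. a timestep embedding)
norm = AdaRMSSketch(dim=64, cond_dim=32)
h, gate = norm(x, cond)
out = x + h * gate           # gated residual, as in _gated_residual
```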