diff --git a/src/lerobot/policies/videovla/modeling_pi05.py b/src/lerobot/policies/videovla/modeling_pi05.py
index 493bd8e62..29ad1dc1a 100644
--- a/src/lerobot/policies/videovla/modeling_pi05.py
+++ b/src/lerobot/policies/videovla/modeling_pi05.py
@@ -826,15 +826,17 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
 
         Returns:
             Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to
-            PaliGemma's hidden dimension, in the same dtype as PaliGemma's language model.
+            PaliGemma's hidden dimension, in the same dtype as the video_resampler.
         """
         if self.video_encoder is None:
             raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.")
 
         device = video_frames.device
 
-        # Determine target dtype: match PaliGemma's language model dtype for consistency
-        paligemma_dtype = self.paligemma_with_expert.paligemma.language_model.embed_tokens.weight.dtype
+        # Determine target dtype: use the video_resampler's dtype for consistency
+        # Note: We use video_resampler (not embed_tokens) because embed_tokens may be
+        # tied/missing in some checkpoints, but video_resampler is always initialized
+        target_dtype = self.video_resampler.ln_kv.weight.dtype
 
         # Move video encoder to the same device if needed
         if next(self.video_encoder.parameters()).device != device:
@@ -869,8 +871,8 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         # Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768])
         video_embeddings = video_outputs.last_hidden_state
 
-        # Convert to PaliGemma's dtype for consistency with the rest of the model
-        video_embeddings = video_embeddings.to(dtype=paligemma_dtype)
+        # Convert to target dtype for consistency with the video_resampler
+        video_embeddings = video_embeddings.to(dtype=target_dtype)
 
         # Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128)
         # This uses cross-attention from learnable queries to the video tokens
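For context, here is a minimal standalone sketch of the dtype-anchoring pattern this patch adopts: read the target dtype from a submodule that is guaranteed to be initialized (the diff uses `video_resampler.ln_kv`; `embed_tokens` may be weight-tied or absent in some checkpoints) and cast the encoder output to it before resampling. The `VideoResampler` class, hidden size, and tensor shapes below are hypothetical stand-ins chosen to mirror the names in the diff, not the actual module.

```python
import torch
import torch.nn as nn


class VideoResampler(nn.Module):
    """Hypothetical stand-in for the Perceiver-style resampler in the diff."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # ln_kv normalizes the resampler's keys/values; because its weight
        # dtype always tracks the module's dtype, it makes a reliable anchor.
        self.ln_kv = nn.LayerNorm(hidden_size)


resampler = VideoResampler(hidden_size=768).to(torch.bfloat16)

# The pattern from the patch: derive the target dtype from an
# always-initialized module instead of the language model's embed_tokens.
target_dtype = resampler.ln_kv.weight.dtype  # torch.bfloat16

# Cast the (e.g., fp32) video encoder output before it enters the resampler.
video_embeddings = torch.randn(2, 4096, 768)  # [B, num_patches * num_frames, hidden]
video_embeddings = video_embeddings.to(dtype=target_dtype)
assert video_embeddings.dtype == target_dtype
```

The design point is that the dtype anchor and the module that consumes the tensor are the same object, so the cast can never disagree with the resampler's parameters even when other parts of the checkpoint are loaded in a different precision or with tied weights.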