fix(videovla): use video_resampler dtype instead of embed_tokens for consistency

The dtype inference in embed_video() was using embed_tokens.weight.dtype,
but embed_tokens can be missing/tied in some checkpoints (e.g., when
loading pi05-video models). This caused a RuntimeError:
"expected scalar type BFloat16 but found Float" because:

- embed_tokens was freshly initialized as float32 (missing from checkpoint)
- video_resampler layers were loaded as bfloat16 (from checkpoint)
- video_embeddings were cast to float32, then passed to bfloat16 layers

Fix: Use video_resampler.ln_kv.weight.dtype as the target dtype source,
since this is the exact layer that requires dtype consistency and is
always present when use_video_encoder=True.
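
The failure mode above can be reproduced in a minimal sketch. The layers below are stand-ins (a LayerNorm for video_resampler.ln_kv, a fresh Embedding for the missing embed_tokens), not the actual PI05 modules; the point is only that the cast target must come from the layer that consumes the tensor:

```python
import torch
import torch.nn as nn

# Stand-in for video_resampler.ln_kv: loaded from checkpoint in bfloat16.
ln_kv = nn.LayerNorm(8).to(torch.bfloat16)

# Stand-in for embed_tokens: freshly initialized, so float32 by default.
embed_tokens = nn.Embedding(10, 8)

# Encoder output (float32), like video_outputs.last_hidden_state.
video_embeddings = torch.randn(2, 4, 8)

# Old behavior: infer dtype from the unrelated embed_tokens -> stays float32,
# which then mismatches the bfloat16 LayerNorm weights downstream.
bad = video_embeddings.to(embed_tokens.weight.dtype)  # still torch.float32

# Fixed behavior: infer dtype from the layer that actually consumes the tensor.
target_dtype = ln_kv.weight.dtype  # torch.bfloat16
out = ln_kv(video_embeddings.to(target_dtype))
```

Casting to `ln_kv.weight.dtype` guarantees the input matches the consuming layer's parameters regardless of how the rest of the checkpoint was initialized.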
commit 02dbaf22ee (parent be2267974a)
Author: Michel Aractingi
Date: 2026-01-26 17:00:04 +00:00
@@ -826,15 +826,17 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         Returns:
             Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to
-                PaliGemma's hidden dimension, in the same dtype as PaliGemma's language model.
+                PaliGemma's hidden dimension, in the same dtype as the video_resampler.
         """
         if self.video_encoder is None:
             raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.")
 
         device = video_frames.device
 
-        # Determine target dtype: match PaliGemma's language model dtype for consistency
-        paligemma_dtype = self.paligemma_with_expert.paligemma.language_model.embed_tokens.weight.dtype
+        # Determine target dtype: use the video_resampler's dtype for consistency
+        # Note: We use video_resampler (not embed_tokens) because embed_tokens may be
+        # tied/missing in some checkpoints, but video_resampler is always initialized
+        target_dtype = self.video_resampler.ln_kv.weight.dtype
 
         # Move video encoder to the same device if needed
         if next(self.video_encoder.parameters()).device != device:
@@ -869,8 +871,8 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         # Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768])
         video_embeddings = video_outputs.last_hidden_state
 
-        # Convert to PaliGemma's dtype for consistency with the rest of the model
-        video_embeddings = video_embeddings.to(dtype=paligemma_dtype)
+        # Convert to target dtype for consistency with the video_resampler
+        video_embeddings = video_embeddings.to(dtype=target_dtype)
 
         # Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128)
         # This uses cross-attention from learnable queries to the video tokens