mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(videovla): use video_resampler dtype instead of embed_tokens for consistency
The dtype inference in embed_video() was using embed_tokens.weight.dtype, but embed_tokens can be missing or tied in some checkpoints (e.g., when loading pi05-video models). This caused a RuntimeError ("expected scalar type BFloat16 but found Float") because:
- embed_tokens was freshly initialized as float32 (it was missing from the checkpoint)
- video_resampler layers were loaded as bfloat16 (from the checkpoint)
- video_embeddings were cast to float32 and then passed to the bfloat16 layers

Fix: Use video_resampler.ln_kv.weight.dtype as the target dtype source, since this is the exact layer that requires dtype consistency and is always present when use_video_encoder=True.
This commit is contained in:
@@ -826,15 +826,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
Returns:
|
||||
Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to
|
||||
PaliGemma's hidden dimension, in the same dtype as PaliGemma's language model.
|
||||
PaliGemma's hidden dimension, in the same dtype as the video_resampler.
|
||||
"""
|
||||
if self.video_encoder is None:
|
||||
raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.")
|
||||
|
||||
device = video_frames.device
|
||||
|
||||
# Determine target dtype: match PaliGemma's language model dtype for consistency
|
||||
paligemma_dtype = self.paligemma_with_expert.paligemma.language_model.embed_tokens.weight.dtype
|
||||
# Determine target dtype: use the video_resampler's dtype for consistency
|
||||
# Note: We use video_resampler (not embed_tokens) because embed_tokens may be
|
||||
# tied/missing in some checkpoints, but video_resampler is always initialized
|
||||
target_dtype = self.video_resampler.ln_kv.weight.dtype
|
||||
|
||||
# Move video encoder to the same device if needed
|
||||
if next(self.video_encoder.parameters()).device != device:
|
||||
@@ -869,8 +871,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768])
|
||||
video_embeddings = video_outputs.last_hidden_state
|
||||
|
||||
# Convert to PaliGemma's dtype for consistency with the rest of the model
|
||||
video_embeddings = video_embeddings.to(dtype=paligemma_dtype)
|
||||
# Convert to target dtype for consistency with the video_resampler
|
||||
video_embeddings = video_embeddings.to(dtype=target_dtype)
|
||||
|
||||
# Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128)
|
||||
# This uses cross-attention from learnable queries to the video tokens
|
||||
|
||||
Reference in New Issue
Block a user