fix(videovla): improve PerceiverResampler and address code review issues

Key fixes for the PI05Video policy implementation:

PerceiverResampler improvements (a minimal forward-pass sketch follows this list):
- Add residual connection (latents + attn_out) for better gradient flow
- Add output LayerNorm after residual connection
- Initialize latents with smaller variance (*0.02) for stability
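
For reference, a sketch of the resampler forward pass after these changes. This is illustrative only: the class name and hyperparameters (dim=768, num_latents=128, num_heads=8) are placeholders, not the config defaults.

import torch
import torch.nn as nn

class ResamplerSketch(nn.Module):  # hypothetical stand-in for PerceiverResampler
    def __init__(self, dim: int = 768, num_latents: int = 128, num_heads: int = 8):
        super().__init__()
        # Small-variance init keeps early attention logits near zero
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln_q = nn.LayerNorm(dim)
        self.ln_kv = nn.LayerNorm(dim)
        self.ln_out = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, D) video tokens; learnable queries are broadcast to the batch
        latents = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)
        q, kv = self.ln_q(latents), self.ln_kv(x)
        attn_out, _ = self.attn(q, kv, kv, need_weights=False)  # (B, num_latents, D)
        # Residual keeps a direct gradient path to the learnable queries
        return self.ln_out(latents + attn_out)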

Bug fixes (the expand()/repeat() aliasing hazard is demonstrated below):
- Replace expand() with repeat() in _preprocess_video to create copies
  instead of memory views, preventing potential in-place modification bugs
- Fix dtype consistency in embed_video: use PaliGemma's dtype instead
  of input dtype for consistent processing throughout the pipeline
- Add bfloat16/float16 support to resize_with_pad_torch
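
A standalone repro of the aliasing hazard that the repeat() change avoids (shapes are illustrative):

import torch

img = torch.zeros(2, 3, 4, 4)                         # (B, C, H, W) single frame
frames_view = img.unsqueeze(1).expand(2, 5, 3, 4, 4)  # view: all 5 "frames" share img's storage
frames_copy = img.unsqueeze(1).repeat(1, 5, 1, 1, 1)  # copy: independent memory per frame

frames_copy[:, 0] += 1.0     # safe: only frame 0 of the copy changes
assert img.abs().sum() == 0  # source tensor untouched

frames_view[:, 0] += 1.0     # writes through to img ...
assert img.abs().sum() > 0   # ... and therefore to every other "frame" of the view
assert torch.equal(frames_view[:, 0], frames_view[:, 4])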

PEFT improvements (the resulting target regex is sanity-checked below):
- Remove state_proj from target modules (PI0-only, not in PI05)
- Add video_proj and video_resampler to PEFT targets for fine-tuning
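
Quick check of what the new target regex does and does not match. Module paths here are hypothetical examples; PEFT applies a string target_modules pattern with re.fullmatch.

import re

core = "action_in_proj|action_out_proj|time_mlp_in|time_mlp_out"
video = r"video_proj|video_resampler\..*"
pattern = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({core}|{video}))"

assert re.fullmatch(pattern, "model.video_proj")
assert re.fullmatch(pattern, "model.video_resampler.attn")
assert re.fullmatch(pattern, "some.path.gemma_expert.model.layers.0.self_attn.v_proj")  # hypothetical path
assert not re.fullmatch(pattern, "model.state_proj")  # no longer targeted in PI05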

Other improvements (the new config validation is exercised below):
- Add warning when use_video_encoder=True but no image features found
- Add gradient checkpointing support for video encoder
- Remove duplicate tokenizer_max_length definition in config
- Add validation for video_num_latents and video_resampler_num_heads
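
A quick sanity check that the new validation fires as intended; _validate is a hypothetical stand-in mirroring the checks added to PI05VideoConfig.__post_init__, and the passing values (128, 8) are illustrative.

def _validate(video_num_latents: int, video_resampler_num_heads: int) -> None:
    # Mirrors the two checks added to PI05VideoConfig.__post_init__
    if video_num_latents < 1:
        raise ValueError(f"video_num_latents must be >= 1, got {video_num_latents}")
    if video_resampler_num_heads < 1:
        raise ValueError(f"video_resampler_num_heads must be >= 1, got {video_resampler_num_heads}")

_validate(128, 8)  # illustrative values pass silently
try:
    _validate(0, 8)
except ValueError as e:
    print(e)  # video_num_latents must be >= 1, got 0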
Author: Michel Aractingi
Date:   2026-01-26 11:16:43 +01:00
Parent: 18bba97cd6
Commit: be2267974a

2 changed files with 61 additions and 17 deletions
@@ -108,8 +108,6 @@ class PI05VideoConfig(PreTrainedConfig):
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 200 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
@@ -138,6 +136,12 @@ class PI05VideoConfig(PreTrainedConfig):
             raise ValueError(
                 f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}"
             )
+        if self.video_num_latents < 1:
+            raise ValueError(f"video_num_latents must be >= 1, got {self.video_num_latents}")
+        if self.video_resampler_num_heads < 1:
+            raise ValueError(
+                f"video_resampler_num_heads must be >= 1, got {self.video_resampler_num_heads}"
+            )

     def validate_features(self) -> None:
         """Validate and set up input/output features."""
@@ -197,7 +197,7 @@ def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
     # Handle dtype-specific clipping
     if images.dtype == torch.uint8:
         resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
-    elif images.dtype == torch.float32:
+    elif images.dtype in (torch.float32, torch.float16, torch.bfloat16):
         resized_images = resized_images.clamp(-1.0, 1.0)
     else:
         raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -315,8 +315,8 @@ class PerceiverResampler(nn.Module):
         self.num_latents = num_latents
         self.dim = dim

-        # Learnable query tokens
-        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        # Learnable query tokens (initialized with small values for stability)
+        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)

         # Cross-attention layer
         self.attn = nn.MultiheadAttention(
@@ -328,6 +328,8 @@ class PerceiverResampler(nn.Module):
         # Layer norms for queries and key-values
         self.ln_q = nn.LayerNorm(dim)
         self.ln_kv = nn.LayerNorm(dim)
+        # Output layer norm (applied after residual connection)
+        self.ln_out = nn.LayerNorm(dim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -347,7 +349,10 @@ class PerceiverResampler(nn.Module):
kv = self.ln_kv(x)
# Cross-attention: queries attend to video tokens
out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D)
attn_out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D)
# Residual connection + output layer norm
out = self.ln_out(latents + attn_out)
return out
@@ -693,6 +698,13 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
         self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
         self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
+
+        # Enable gradient checkpointing for video encoder if available and not frozen
+        if self.video_encoder is not None and not self.config.freeze_video_encoder:
+            if hasattr(self.video_encoder, "gradient_checkpointing_enable"):
+                self.video_encoder.gradient_checkpointing_enable()
+                logging.info("Enabled gradient checkpointing for video encoder")
+
         logging.info("Enabled gradient checkpointing for PI05Pytorch model")

     def gradient_checkpointing_disable(self):
@@ -701,6 +713,13 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
         self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
         self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
+
+        # Disable gradient checkpointing for video encoder if available
+        if self.video_encoder is not None:
+            if hasattr(self.video_encoder, "gradient_checkpointing_disable"):
+                self.video_encoder.gradient_checkpointing_disable()
+                logging.info("Disabled gradient checkpointing for video encoder")
+
         logging.info("Disabled gradient checkpointing for PI05Pytorch model")

     def _rtc_enabled(self):
@@ -807,13 +826,15 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         Returns:
             Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to
-            PaliGemma's hidden dimension.
+            PaliGemma's hidden dimension, in the same dtype as PaliGemma's language model.
         """
         if self.video_encoder is None:
             raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.")

         device = video_frames.device
-        dtype = video_frames.dtype
+        # Determine target dtype: match PaliGemma's language model dtype for consistency
+        paligemma_dtype = self.paligemma_with_expert.paligemma.language_model.embed_tokens.weight.dtype

         # Move video encoder to the same device if needed
         if next(self.video_encoder.parameters()).device != device:
@@ -848,8 +869,8 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         # Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768])
         video_embeddings = video_outputs.last_hidden_state

-        # Convert to working dtype
-        video_embeddings = video_embeddings.to(dtype=dtype)
+        # Convert to PaliGemma's dtype for consistency with the rest of the model
+        video_embeddings = video_embeddings.to(dtype=paligemma_dtype)

         # Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128)
         # This uses cross-attention from learnable queries to the video tokens
@@ -1369,6 +1390,12 @@ class PI05VideoPolicy(PreTrainedPolicy):
         if self.config.image_features:
             return next(iter(self.config.image_features.keys()))
+        # Warn if video encoder is enabled but no image features found
+        logging.warning(
+            "use_video_encoder=True but no image features found in config. "
+            "Video encoding will be skipped. Either set video_encoder_camera_key "
+            "or ensure image_features contains at least one camera."
+        )
         return None

     def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
@@ -1496,7 +1523,9 @@ class PI05VideoPolicy(PreTrainedPolicy):
             # Single frame [B, C, H, W] - expand to video by repeating
             B, C, H, W = img.shape
             if self.config.video_padding_mode == "repeat":
-                video_frames = img.unsqueeze(1).expand(B, self.config.video_num_frames, C, H, W)
+                # Use repeat() instead of expand() to create actual copies, not views
+                # This prevents potential issues if downstream operations modify tensors in-place
+                video_frames = img.unsqueeze(1).repeat(1, self.config.video_num_frames, 1, 1, 1)
             else:  # zero padding
                 video_frames = torch.zeros(
                     B, self.config.video_num_frames, C, H, W, dtype=img.dtype, device=img.device
@@ -1513,8 +1542,9 @@ class PI05VideoPolicy(PreTrainedPolicy):
             if self.config.video_padding_mode == "repeat":
                 # Repeat the first frame to fill missing frames at the beginning
+                # Use repeat() instead of expand() to create actual copies, not views
                 first_frame = video_frames[:, 0:1]  # [B, 1, C, H, W]
-                padding = first_frame.expand(B, num_missing, C, H, W)
+                padding = first_frame.repeat(1, num_missing, 1, 1, 1)
                 video_frames = torch.cat([padding, video_frames], dim=1)
             else:  # zero padding
                 # Zero-pad at the beginning
@@ -1625,11 +1655,21 @@ class PI05VideoPolicy(PreTrainedPolicy):
         return loss, loss_dict

     def _get_default_peft_targets(self) -> dict[str, any]:
-        """Return default PEFT target modules for PI0.5 fine-tuning."""
-        common_projections = (
-            "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
-        )
-        target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
+        """Return default PEFT target modules for PI0.5 fine-tuning.
+
+        Note: PI05 does NOT have state_proj (that's PI0 only). PI05 tokenizes state
+        into the language prompt instead.
+        """
+        # Core PI05 projections (no state_proj in PI05)
+        core_projections = "action_in_proj|action_out_proj|time_mlp_in|time_mlp_out"
+        # Video-related modules (only present if use_video_encoder=True)
+        video_projections = "video_proj|video_resampler\\..*"
+        # Combine all projections
+        all_projections = f"{core_projections}|{video_projections}"
+        target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({all_projections}))"
         return {
             "target_modules": target_modules,
             "modules_to_save": [],