mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(videovla): improve PerceiverResampler and address code review issues
Key fixes for the PI05Video policy implementation: PerceiverResampler improvements: - Add residual connection (latents + attn_out) for better gradient flow - Add output LayerNorm after residual connection - Initialize latents with smaller variance (*0.02) for stability Bug fixes: - Replace expand() with repeat() in _preprocess_video to create copies instead of memory views, preventing potential in-place modification bugs - Fix dtype consistency in embed_video: use PaliGemma's dtype instead of input dtype for consistent processing throughout the pipeline - Add bfloat16/float16 support to resize_with_pad_torch PEFT improvements: - Remove state_proj from target modules (PI0-only, not in PI05) - Add video_proj and video_resampler to PEFT targets for fine-tuning Other improvements: - Add warning when use_video_encoder=True but no image features found - Add gradient checkpointing support for video encoder - Remove duplicate tokenizer_max_length definition in config - Add validation for video_num_latents and video_resampler_num_heads
This commit is contained in:
@@ -108,8 +108,6 @@ class PI05VideoConfig(PreTrainedConfig):
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@@ -138,6 +136,12 @@ class PI05VideoConfig(PreTrainedConfig):
|
||||
raise ValueError(
|
||||
f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}"
|
||||
)
|
||||
if self.video_num_latents < 1:
|
||||
raise ValueError(f"video_num_latents must be >= 1, got {self.video_num_latents}")
|
||||
if self.video_resampler_num_heads < 1:
|
||||
raise ValueError(
|
||||
f"video_resampler_num_heads must be >= 1, got {self.video_resampler_num_heads}"
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
|
||||
@@ -197,7 +197,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
# Handle dtype-specific clipping
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
elif images.dtype in (torch.float32, torch.float16, torch.bfloat16):
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
@@ -315,8 +315,8 @@ class PerceiverResampler(nn.Module):
|
||||
self.num_latents = num_latents
|
||||
self.dim = dim
|
||||
|
||||
# Learnable query tokens
|
||||
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
||||
# Learnable query tokens (initialized with small values for stability)
|
||||
self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
|
||||
|
||||
# Cross-attention layer
|
||||
self.attn = nn.MultiheadAttention(
|
||||
@@ -328,6 +328,8 @@ class PerceiverResampler(nn.Module):
|
||||
# Layer norms for queries and key-values
|
||||
self.ln_q = nn.LayerNorm(dim)
|
||||
self.ln_kv = nn.LayerNorm(dim)
|
||||
# Output layer norm (applied after residual connection)
|
||||
self.ln_out = nn.LayerNorm(dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@@ -347,7 +349,10 @@ class PerceiverResampler(nn.Module):
|
||||
kv = self.ln_kv(x)
|
||||
|
||||
# Cross-attention: queries attend to video tokens
|
||||
out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D)
|
||||
attn_out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D)
|
||||
|
||||
# Residual connection + output layer norm
|
||||
out = self.ln_out(latents + attn_out)
|
||||
|
||||
return out
|
||||
|
||||
@@ -693,6 +698,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
|
||||
# Enable gradient checkpointing for video encoder if available and not frozen
|
||||
if self.video_encoder is not None and not self.config.freeze_video_encoder:
|
||||
if hasattr(self.video_encoder, "gradient_checkpointing_enable"):
|
||||
self.video_encoder.gradient_checkpointing_enable()
|
||||
logging.info("Enabled gradient checkpointing for video encoder")
|
||||
|
||||
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
@@ -701,6 +713,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
|
||||
# Disable gradient checkpointing for video encoder if available
|
||||
if self.video_encoder is not None:
|
||||
if hasattr(self.video_encoder, "gradient_checkpointing_disable"):
|
||||
self.video_encoder.gradient_checkpointing_disable()
|
||||
logging.info("Disabled gradient checkpointing for video encoder")
|
||||
|
||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
def _rtc_enabled(self):
|
||||
@@ -807,13 +826,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
Returns:
|
||||
Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to
|
||||
PaliGemma's hidden dimension.
|
||||
PaliGemma's hidden dimension, in the same dtype as PaliGemma's language model.
|
||||
"""
|
||||
if self.video_encoder is None:
|
||||
raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.")
|
||||
|
||||
device = video_frames.device
|
||||
dtype = video_frames.dtype
|
||||
|
||||
# Determine target dtype: match PaliGemma's language model dtype for consistency
|
||||
paligemma_dtype = self.paligemma_with_expert.paligemma.language_model.embed_tokens.weight.dtype
|
||||
|
||||
# Move video encoder to the same device if needed
|
||||
if next(self.video_encoder.parameters()).device != device:
|
||||
@@ -848,8 +869,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768])
|
||||
video_embeddings = video_outputs.last_hidden_state
|
||||
|
||||
# Convert to working dtype
|
||||
video_embeddings = video_embeddings.to(dtype=dtype)
|
||||
# Convert to PaliGemma's dtype for consistency with the rest of the model
|
||||
video_embeddings = video_embeddings.to(dtype=paligemma_dtype)
|
||||
|
||||
# Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128)
|
||||
# This uses cross-attention from learnable queries to the video tokens
|
||||
@@ -1369,6 +1390,12 @@ class PI05VideoPolicy(PreTrainedPolicy):
|
||||
if self.config.image_features:
|
||||
return next(iter(self.config.image_features.keys()))
|
||||
|
||||
# Warn if video encoder is enabled but no image features found
|
||||
logging.warning(
|
||||
"use_video_encoder=True but no image features found in config. "
|
||||
"Video encoding will be skipped. Either set video_encoder_camera_key "
|
||||
"or ensure image_features contains at least one camera."
|
||||
)
|
||||
return None
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
@@ -1496,7 +1523,9 @@ class PI05VideoPolicy(PreTrainedPolicy):
|
||||
# Single frame [B, C, H, W] - expand to video by repeating
|
||||
B, C, H, W = img.shape
|
||||
if self.config.video_padding_mode == "repeat":
|
||||
video_frames = img.unsqueeze(1).expand(B, self.config.video_num_frames, C, H, W)
|
||||
# Use repeat() instead of expand() to create actual copies, not views
|
||||
# This prevents potential issues if downstream operations modify tensors in-place
|
||||
video_frames = img.unsqueeze(1).repeat(1, self.config.video_num_frames, 1, 1, 1)
|
||||
else: # zero padding
|
||||
video_frames = torch.zeros(
|
||||
B, self.config.video_num_frames, C, H, W, dtype=img.dtype, device=img.device
|
||||
@@ -1513,8 +1542,9 @@ class PI05VideoPolicy(PreTrainedPolicy):
|
||||
|
||||
if self.config.video_padding_mode == "repeat":
|
||||
# Repeat the first frame to fill missing frames at the beginning
|
||||
# Use repeat() instead of expand() to create actual copies, not views
|
||||
first_frame = video_frames[:, 0:1] # [B, 1, C, H, W]
|
||||
padding = first_frame.expand(B, num_missing, C, H, W)
|
||||
padding = first_frame.repeat(1, num_missing, 1, 1, 1)
|
||||
video_frames = torch.cat([padding, video_frames], dim=1)
|
||||
else: # zero padding
|
||||
# Zero-pad at the beginning
|
||||
@@ -1625,11 +1655,21 @@ class PI05VideoPolicy(PreTrainedPolicy):
|
||||
return loss, loss_dict
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for PI0.5 fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
||||
"""Return default PEFT target modules for PI0.5 fine-tuning.
|
||||
|
||||
Note: PI05 does NOT have state_proj (that's PI0 only). PI05 tokenizes state
|
||||
into the language prompt instead.
|
||||
"""
|
||||
# Core PI05 projections (no state_proj in PI05)
|
||||
core_projections = "action_in_proj|action_out_proj|time_mlp_in|time_mlp_out"
|
||||
|
||||
# Video-related modules (only present if use_video_encoder=True)
|
||||
video_projections = "video_proj|video_resampler\\..*"
|
||||
|
||||
# Combine all projections
|
||||
all_projections = f"{core_projections}|{video_projections}"
|
||||
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({all_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
|
||||
Reference in New Issue
Block a user