mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix(videovla): improve PerceiverResampler and address code review issues
Key fixes for the PI05Video policy implementation: PerceiverResampler improvements: - Add residual connection (latents + attn_out) for better gradient flow - Add output LayerNorm after residual connection - Initialize latents with smaller variance (*0.02) for stability Bug fixes: - Replace expand() with repeat() in _preprocess_video to create copies instead of memory views, preventing potential in-place modification bugs - Fix dtype consistency in embed_video: use PaliGemma's dtype instead of input dtype for consistent processing throughout the pipeline - Add bfloat16/float16 support to resize_with_pad_torch PEFT improvements: - Remove state_proj from target modules (PI0-only, not in PI05) - Add video_proj and video_resampler to PEFT targets for fine-tuning Other improvements: - Add warning when use_video_encoder=True but no image features found - Add gradient checkpointing support for video encoder - Remove duplicate tokenizer_max_length definition in config - Add validation for video_num_latents and video_resampler_num_heads
This commit is contained in:
@@ -108,8 +108,6 @@ class PI05VideoConfig(PreTrainedConfig):
|
|||||||
scheduler_decay_steps: int = 30_000
|
scheduler_decay_steps: int = 30_000
|
||||||
scheduler_decay_lr: float = 2.5e-6
|
scheduler_decay_lr: float = 2.5e-6
|
||||||
|
|
||||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
@@ -138,6 +136,12 @@ class PI05VideoConfig(PreTrainedConfig):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}"
|
f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}"
|
||||||
)
|
)
|
||||||
|
if self.video_num_latents < 1:
|
||||||
|
raise ValueError(f"video_num_latents must be >= 1, got {self.video_num_latents}")
|
||||||
|
if self.video_resampler_num_heads < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"video_resampler_num_heads must be >= 1, got {self.video_resampler_num_heads}"
|
||||||
|
)
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate and set up input/output features."""
|
"""Validate and set up input/output features."""
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
# Handle dtype-specific clipping
|
# Handle dtype-specific clipping
|
||||||
if images.dtype == torch.uint8:
|
if images.dtype == torch.uint8:
|
||||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||||
elif images.dtype == torch.float32:
|
elif images.dtype in (torch.float32, torch.float16, torch.bfloat16):
|
||||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||||
@@ -315,8 +315,8 @@ class PerceiverResampler(nn.Module):
|
|||||||
self.num_latents = num_latents
|
self.num_latents = num_latents
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
||||||
# Learnable query tokens
|
# Learnable query tokens (initialized with small values for stability)
|
||||||
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
|
||||||
|
|
||||||
# Cross-attention layer
|
# Cross-attention layer
|
||||||
self.attn = nn.MultiheadAttention(
|
self.attn = nn.MultiheadAttention(
|
||||||
@@ -328,6 +328,8 @@ class PerceiverResampler(nn.Module):
|
|||||||
# Layer norms for queries and key-values
|
# Layer norms for queries and key-values
|
||||||
self.ln_q = nn.LayerNorm(dim)
|
self.ln_q = nn.LayerNorm(dim)
|
||||||
self.ln_kv = nn.LayerNorm(dim)
|
self.ln_kv = nn.LayerNorm(dim)
|
||||||
|
# Output layer norm (applied after residual connection)
|
||||||
|
self.ln_out = nn.LayerNorm(dim)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -347,7 +349,10 @@ class PerceiverResampler(nn.Module):
|
|||||||
kv = self.ln_kv(x)
|
kv = self.ln_kv(x)
|
||||||
|
|
||||||
# Cross-attention: queries attend to video tokens
|
# Cross-attention: queries attend to video tokens
|
||||||
out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D)
|
attn_out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D)
|
||||||
|
|
||||||
|
# Residual connection + output layer norm
|
||||||
|
out = self.ln_out(latents + attn_out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -693,6 +698,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||||
|
|
||||||
|
# Enable gradient checkpointing for video encoder if available and not frozen
|
||||||
|
if self.video_encoder is not None and not self.config.freeze_video_encoder:
|
||||||
|
if hasattr(self.video_encoder, "gradient_checkpointing_enable"):
|
||||||
|
self.video_encoder.gradient_checkpointing_enable()
|
||||||
|
logging.info("Enabled gradient checkpointing for video encoder")
|
||||||
|
|
||||||
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
||||||
|
|
||||||
def gradient_checkpointing_disable(self):
|
def gradient_checkpointing_disable(self):
|
||||||
@@ -701,6 +713,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
||||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||||
|
|
||||||
|
# Disable gradient checkpointing for video encoder if available
|
||||||
|
if self.video_encoder is not None:
|
||||||
|
if hasattr(self.video_encoder, "gradient_checkpointing_disable"):
|
||||||
|
self.video_encoder.gradient_checkpointing_disable()
|
||||||
|
logging.info("Disabled gradient checkpointing for video encoder")
|
||||||
|
|
||||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||||
|
|
||||||
def _rtc_enabled(self):
|
def _rtc_enabled(self):
|
||||||
@@ -807,13 +826,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to
|
Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to
|
||||||
PaliGemma's hidden dimension.
|
PaliGemma's hidden dimension, in the same dtype as PaliGemma's language model.
|
||||||
"""
|
"""
|
||||||
if self.video_encoder is None:
|
if self.video_encoder is None:
|
||||||
raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.")
|
raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.")
|
||||||
|
|
||||||
device = video_frames.device
|
device = video_frames.device
|
||||||
dtype = video_frames.dtype
|
|
||||||
|
# Determine target dtype: match PaliGemma's language model dtype for consistency
|
||||||
|
paligemma_dtype = self.paligemma_with_expert.paligemma.language_model.embed_tokens.weight.dtype
|
||||||
|
|
||||||
# Move video encoder to the same device if needed
|
# Move video encoder to the same device if needed
|
||||||
if next(self.video_encoder.parameters()).device != device:
|
if next(self.video_encoder.parameters()).device != device:
|
||||||
@@ -848,8 +869,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768])
|
# Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768])
|
||||||
video_embeddings = video_outputs.last_hidden_state
|
video_embeddings = video_outputs.last_hidden_state
|
||||||
|
|
||||||
# Convert to working dtype
|
# Convert to PaliGemma's dtype for consistency with the rest of the model
|
||||||
video_embeddings = video_embeddings.to(dtype=dtype)
|
video_embeddings = video_embeddings.to(dtype=paligemma_dtype)
|
||||||
|
|
||||||
# Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128)
|
# Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128)
|
||||||
# This uses cross-attention from learnable queries to the video tokens
|
# This uses cross-attention from learnable queries to the video tokens
|
||||||
@@ -1369,6 +1390,12 @@ class PI05VideoPolicy(PreTrainedPolicy):
|
|||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
return next(iter(self.config.image_features.keys()))
|
return next(iter(self.config.image_features.keys()))
|
||||||
|
|
||||||
|
# Warn if video encoder is enabled but no image features found
|
||||||
|
logging.warning(
|
||||||
|
"use_video_encoder=True but no image features found in config. "
|
||||||
|
"Video encoding will be skipped. Either set video_encoder_camera_key "
|
||||||
|
"or ensure image_features contains at least one camera."
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||||
@@ -1496,7 +1523,9 @@ class PI05VideoPolicy(PreTrainedPolicy):
|
|||||||
# Single frame [B, C, H, W] - expand to video by repeating
|
# Single frame [B, C, H, W] - expand to video by repeating
|
||||||
B, C, H, W = img.shape
|
B, C, H, W = img.shape
|
||||||
if self.config.video_padding_mode == "repeat":
|
if self.config.video_padding_mode == "repeat":
|
||||||
video_frames = img.unsqueeze(1).expand(B, self.config.video_num_frames, C, H, W)
|
# Use repeat() instead of expand() to create actual copies, not views
|
||||||
|
# This prevents potential issues if downstream operations modify tensors in-place
|
||||||
|
video_frames = img.unsqueeze(1).repeat(1, self.config.video_num_frames, 1, 1, 1)
|
||||||
else: # zero padding
|
else: # zero padding
|
||||||
video_frames = torch.zeros(
|
video_frames = torch.zeros(
|
||||||
B, self.config.video_num_frames, C, H, W, dtype=img.dtype, device=img.device
|
B, self.config.video_num_frames, C, H, W, dtype=img.dtype, device=img.device
|
||||||
@@ -1513,8 +1542,9 @@ class PI05VideoPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
if self.config.video_padding_mode == "repeat":
|
if self.config.video_padding_mode == "repeat":
|
||||||
# Repeat the first frame to fill missing frames at the beginning
|
# Repeat the first frame to fill missing frames at the beginning
|
||||||
|
# Use repeat() instead of expand() to create actual copies, not views
|
||||||
first_frame = video_frames[:, 0:1] # [B, 1, C, H, W]
|
first_frame = video_frames[:, 0:1] # [B, 1, C, H, W]
|
||||||
padding = first_frame.expand(B, num_missing, C, H, W)
|
padding = first_frame.repeat(1, num_missing, 1, 1, 1)
|
||||||
video_frames = torch.cat([padding, video_frames], dim=1)
|
video_frames = torch.cat([padding, video_frames], dim=1)
|
||||||
else: # zero padding
|
else: # zero padding
|
||||||
# Zero-pad at the beginning
|
# Zero-pad at the beginning
|
||||||
@@ -1625,11 +1655,21 @@ class PI05VideoPolicy(PreTrainedPolicy):
|
|||||||
return loss, loss_dict
|
return loss, loss_dict
|
||||||
|
|
||||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||||
"""Return default PEFT target modules for PI0.5 fine-tuning."""
|
"""Return default PEFT target modules for PI0.5 fine-tuning.
|
||||||
common_projections = (
|
|
||||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
Note: PI05 does NOT have state_proj (that's PI0 only). PI05 tokenizes state
|
||||||
)
|
into the language prompt instead.
|
||||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
"""
|
||||||
|
# Core PI05 projections (no state_proj in PI05)
|
||||||
|
core_projections = "action_in_proj|action_out_proj|time_mlp_in|time_mlp_out"
|
||||||
|
|
||||||
|
# Video-related modules (only present if use_video_encoder=True)
|
||||||
|
video_projections = "video_proj|video_resampler\\..*"
|
||||||
|
|
||||||
|
# Combine all projections
|
||||||
|
all_projections = f"{core_projections}|{video_projections}"
|
||||||
|
|
||||||
|
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({all_projections}))"
|
||||||
return {
|
return {
|
||||||
"target_modules": target_modules,
|
"target_modules": target_modules,
|
||||||
"modules_to_save": [],
|
"modules_to_save": [],
|
||||||
|
|||||||
Reference in New Issue
Block a user