rename config param for multiple vision encoders

This commit is contained in:
Bryson Jones
2025-12-11 09:58:57 -08:00
parent 1f74982469
commit 51dfee43f4
3 changed files with 5 additions and 7 deletions
+1 -1
View File
@@ -157,7 +157,7 @@ The model supports two positional encoding methods for action sequences:
# Use separate vision encoder per camera # Use separate vision encoder per camera
# This may be useful when cameras have significantly different characteristics, but # This may be useful when cameras have significantly different characteristics, but
# be wary of increased VRAM footprint. # be wary of increased VRAM footprint.
--policy.use_separate_encoder_per_camera=true --policy.use_separate_rgb_encoder_per_camera=true
# Image preprocessing # Image preprocessing
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups --policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
@@ -71,7 +71,7 @@ class MultiTaskDiTConfig(PreTrainedConfig):
# Vision Encoder (CLIP) # Vision Encoder (CLIP)
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
use_separate_encoder_per_camera: bool = False # Separate encoder per camera view use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
image_resize_shape: tuple[int, int] | None = None # Resize images before crop image_resize_shape: tuple[int, int] | None = None # Resize images before crop
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default) image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
@@ -113,9 +113,7 @@ class MultiTaskDiTConfig(PreTrainedConfig):
"""Validate configuration parameters.""" """Validate configuration parameters."""
# Objective validation # Objective validation
if self.objective not in ["diffusion", "flow_matching"]: if self.objective not in ["diffusion", "flow_matching"]:
raise ValueError( raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")
f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'"
)
# Transformer validation # Transformer validation
if self.hidden_dim <= 0: if self.hidden_dim <= 0:
@@ -246,7 +246,7 @@ class ObservationEncoder(nn.Module):
self.num_cameras = len(config.image_features) self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys()) self.camera_names = list(config.image_features.keys())
if config.use_separate_encoder_per_camera: if config.use_separate_rgb_encoder_per_camera:
self.vision_encoders = nn.ModuleList( self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names] [CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
) )
@@ -326,7 +326,7 @@ class ObservationEncoder(nn.Module):
if len(images.shape) == 5: if len(images.shape) == 5:
images = images.unsqueeze(1) images = images.unsqueeze(1)
if self.config.use_separate_encoder_per_camera: if self.config.use_separate_rgb_encoder_per_camera:
camera_features = [] camera_features = []
for cam_idx in range(self.num_cameras): for cam_idx in range(self.num_cameras):
cam_images = images[:, :, cam_idx] cam_images = images[:, :, cam_idx]