rename config param for multiple vision encoders

This commit is contained in:
Bryson Jones
2025-12-11 09:58:57 -08:00
parent 1f74982469
commit 51dfee43f4
3 changed files with 5 additions and 7 deletions
+1 -1
View File
@@ -157,7 +157,7 @@ The model supports two positional encoding methods for action sequences:
# Use separate vision encoder per camera
# This may be useful when cameras have significantly different characteristics, but
# be wary of increased VRAM footprint.
--policy.use_separate_encoder_per_camera=true
--policy.use_separate_rgb_encoder_per_camera=true
# Image preprocessing
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
@@ -71,7 +71,7 @@ class MultiTaskDiTConfig(PreTrainedConfig):
# Vision Encoder (CLIP)
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
use_separate_encoder_per_camera: bool = False # Separate encoder per camera view
use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
@@ -113,9 +113,7 @@ class MultiTaskDiTConfig(PreTrainedConfig):
"""Validate configuration parameters."""
# Objective validation
if self.objective not in ["diffusion", "flow_matching"]:
raise ValueError(
f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'"
)
raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")
# Transformer validation
if self.hidden_dim <= 0:
@@ -246,7 +246,7 @@ class ObservationEncoder(nn.Module):
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys())
if config.use_separate_encoder_per_camera:
if config.use_separate_rgb_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
@@ -326,7 +326,7 @@ class ObservationEncoder(nn.Module):
if len(images.shape) == 5:
images = images.unsqueeze(1)
if self.config.use_separate_encoder_per_camera:
if self.config.use_separate_rgb_encoder_per_camera:
camera_features = []
for cam_idx in range(self.num_cameras):
cam_images = images[:, :, cam_idx]