mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
rename config param for multiple vision encoders
This commit is contained in:
@@ -157,7 +157,7 @@ The model supports two positional encoding methods for action sequences:
|
||||
# Use separate vision encoder per camera
|
||||
# This may be useful when cameras have significantly different characteristics, but
|
||||
# be wary of increased VRAM footprint.
|
||||
--policy.use_separate_encoder_per_camera=true
|
||||
--policy.use_separate_rgb_encoder_per_camera=true
|
||||
|
||||
# Image preprocessing
|
||||
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
|
||||
|
||||
@@ -71,7 +71,7 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
|
||||
# Vision Encoder (CLIP)
|
||||
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
||||
use_separate_encoder_per_camera: bool = False # Separate encoder per camera view
|
||||
use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view
|
||||
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
|
||||
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
|
||||
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
|
||||
@@ -113,9 +113,7 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
"""Validate configuration parameters."""
|
||||
# Objective validation
|
||||
if self.objective not in ["diffusion", "flow_matching"]:
|
||||
raise ValueError(
|
||||
f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'"
|
||||
)
|
||||
raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")
|
||||
|
||||
# Transformer validation
|
||||
if self.hidden_dim <= 0:
|
||||
|
||||
@@ -246,7 +246,7 @@ class ObservationEncoder(nn.Module):
|
||||
self.num_cameras = len(config.image_features)
|
||||
self.camera_names = list(config.image_features.keys())
|
||||
|
||||
if config.use_separate_encoder_per_camera:
|
||||
if config.use_separate_rgb_encoder_per_camera:
|
||||
self.vision_encoders = nn.ModuleList(
|
||||
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
|
||||
)
|
||||
@@ -326,7 +326,7 @@ class ObservationEncoder(nn.Module):
|
||||
if len(images.shape) == 5:
|
||||
images = images.unsqueeze(1)
|
||||
|
||||
if self.config.use_separate_encoder_per_camera:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
camera_features = []
|
||||
for cam_idx in range(self.num_cameras):
|
||||
cam_images = images[:, :, cam_idx]
|
||||
|
||||
Reference in New Issue
Block a user