Expand the observation encoder to support different-size encoders for vision and text

This commit is contained in:
Bryson Jones
2025-11-21 14:31:35 -08:00
parent ab97d5c019
commit 8b9fada80f
2 changed files with 26 additions and 24 deletions
@@ -232,20 +232,18 @@ class DinoV3EncoderConfig(VisionEncoderConfig):
DinoV3 is a self-supervised Vision Transformer trained by Meta.
CLS token usage and spatial feature extraction are handled automatically.
Available backbones:
Any timm model with "dinov3" in the name can be used. Examples:
- vit_base_patch16_dinov3.lvd1689m (768 dims)
- vit_large_patch16_dinov3.lvd1689m (1024 dims)
"""
backbone: str = "vit_base_patch16_dinov3.lvd1689m"
def __post_init__(self):
    """Validate that the configured backbone is a DinoV3 model.

    Runs the parent class's ``__post_init__`` first, then checks the
    backbone name.

    Raises:
        ValueError: If ``self.backbone`` does not contain "dinov3"
            (case-insensitive).
    """
    super().__post_init__()

    # Any backbone whose name contains "dinov3" is accepted, so the encoder
    # size is not limited to a fixed allow-list of model names.
    if "dinov3" not in self.backbone.lower():
        raise ValueError(f"backbone must be a DinoV3 model (contain 'dinov3'), got '{self.backbone}'")
@VisionEncoderConfig.register_subclass("clip")
@@ -258,17 +256,18 @@ class CLIPVisionEncoderConfig(VisionEncoderConfig):
CLIP's internal preprocessing (resize to 224x224) can be overridden
by setting resize_shape and crop_shape.
Available backbones:
Any timm model with "clip" in the name can be used. Examples:
- vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224)
- vit_large_patch14_clip_224.openai (1024 dims)
"""
backbone: str = "vit_base_patch16_clip_224.openai"
def __post_init__(self):
    """Validate that the configured backbone is a CLIP vision model.

    Runs the parent class's ``__post_init__`` first, then checks the
    backbone name.

    Raises:
        ValueError: If ``self.backbone`` does not contain "clip"
            (case-insensitive).
    """
    super().__post_init__()

    # The substring check is what ties this config to the CLIP encoder type;
    # a single raise replaces the previous duplicated (unreachable) one.
    if "clip" not in self.backbone.lower():
        raise ValueError(f"backbone must be a CLIP model (contain 'clip'), got '{self.backbone}'")
@dataclass
@@ -294,14 +293,19 @@ class CLIPTextEncoderConfig(TextEncoderConfig):
used to condition the policy. The text embeddings are processed by
a learnable projection layer before being concatenated into the
conditioning vector.
Any HuggingFace CLIP model can be used. Examples:
- openai/clip-vit-base-patch16 (default)
- openai/clip-vit-large-patch14
"""
model: str = "openai/clip-vit-base-patch16"
def __post_init__(self):
    """Validate that the configured text model is a CLIP model.

    Runs the parent class's ``__post_init__`` first, then checks the
    model name.

    Raises:
        ValueError: If ``self.model`` does not contain "clip"
            (case-insensitive).
    """
    super().__post_init__()

    # The substring check is what ties this config to the CLIP encoder type;
    # a single raise replaces the previous duplicated (unreachable) one.
    if "clip" not in self.model.lower():
        raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
@dataclass
@@ -123,6 +123,11 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder):
def create_vision_encoder(config) -> BaseVisionEncoder:
"""Create a vision encoder from config.
Supports any timm model with "clip" or "dinov3" in the backbone name.
The encoder type is automatically detected based on the backbone name.
"""
backbone_name = config.backbone.lower()
# Check if it's a CLIP model
@@ -136,7 +141,7 @@ def create_vision_encoder(config) -> BaseVisionEncoder:
else:
raise ValueError(
f"Unsupported vision backbone: {config.backbone}. "
f"Currently supported: DinoV3 models and CLIP models"
f"Currently supported: any timm model with 'dinov3' or 'clip' in the name"
)
@@ -148,26 +153,19 @@ VISION_ENCODER_REGISTRY: dict[str, type] = {
def register_vision_encoder(name: str, encoder_class: type):
    """Register a new vision encoder type.

    Stores ``encoder_class`` in ``VISION_ENCODER_REGISTRY`` under ``name``;
    an existing entry with the same name is replaced.

    Args:
        name: Identifier for the encoder type.
        encoder_class: Class implementing BaseVisionEncoder interface.
    """
    VISION_ENCODER_REGISTRY[name] = encoder_class
def get_registered_encoders() -> dict[str, type]:
    """Get all registered vision encoder types.

    Returns:
        Dictionary mapping encoder names to classes. This is a shallow copy
        of ``VISION_ENCODER_REGISTRY``, so adding or removing keys on the
        result does not change the registry itself.
    """
    return VISION_ENCODER_REGISTRY.copy()
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and learnable projection."""
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()