mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 07:37:10 +00:00
expand the observation encoder to support differnt size encoders for vision and text
This commit is contained in:
@@ -232,20 +232,18 @@ class DinoV3EncoderConfig(VisionEncoderConfig):
|
||||
DinoV3 is a self-supervised Vision Transformer trained by Meta.
|
||||
CLS token usage and spatial feature extraction are handled automatically.
|
||||
|
||||
Available backbones:
|
||||
Any timm model with "dinov3" in the name can be used. Examples:
|
||||
- vit_base_patch16_dinov3.lvd1689m (768 dims)
|
||||
- vit_large_patch14_dinov3.lvd142m (1024 dims)
|
||||
"""
|
||||
|
||||
backbone: str = "vit_base_patch16_dinov3.lvd1689m"
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Validate backbone name
|
||||
valid_backbones = [
|
||||
"vit_base_patch16_dinov3.lvd1689m",
|
||||
]
|
||||
if self.backbone not in valid_backbones:
|
||||
raise ValueError(f"backbone must be one of {valid_backbones}, got '{self.backbone}'")
|
||||
# Validate that backbone name contains "dinov3" to ensure correct encoder type
|
||||
if "dinov3" not in self.backbone.lower():
|
||||
raise ValueError(f"backbone must be a DinoV3 model (contain 'dinov3'), got '{self.backbone}'")
|
||||
|
||||
|
||||
@VisionEncoderConfig.register_subclass("clip")
|
||||
@@ -258,17 +256,18 @@ class CLIPVisionEncoderConfig(VisionEncoderConfig):
|
||||
CLIP's internal preprocessing (resize to 224x224) can be overridden
|
||||
by setting resize_shape and crop_shape.
|
||||
|
||||
Available backbones:
|
||||
Any timm model with "clip" in the name can be used. Examples:
|
||||
- vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224)
|
||||
- vit_large_patch14_clip_224.openai (1024 dims)
|
||||
"""
|
||||
|
||||
backbone: str = "vit_base_patch16_clip_224.openai"
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Validate backbone name
|
||||
# Validate that backbone name contains "clip" to ensure correct encoder type
|
||||
if "clip" not in self.backbone.lower():
|
||||
raise ValueError(f"backbone must be a CLIP model, got '{self.backbone}'")
|
||||
raise ValueError(f"backbone must be a CLIP model (contain 'clip'), got '{self.backbone}'")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -294,14 +293,19 @@ class CLIPTextEncoderConfig(TextEncoderConfig):
|
||||
used to condition the policy. The text embeddings are processed by
|
||||
a learnable projection layer before being concatenated into the
|
||||
conditioning vector.
|
||||
|
||||
Any HuggingFace CLIP model can be used. Examples:
|
||||
- openai/clip-vit-base-patch16 (default)
|
||||
- openai/clip-vit-large-patch14
|
||||
"""
|
||||
|
||||
model: str = "openai/clip-vit-base-patch16"
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Validate that model name contains "clip" to ensure correct encoder type
|
||||
if "clip" not in self.model.lower():
|
||||
raise ValueError(f"CLIP text encoder requires a CLIP model. Got '{self.model}'")
|
||||
raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -123,6 +123,11 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder):
|
||||
|
||||
|
||||
def create_vision_encoder(config) -> BaseVisionEncoder:
|
||||
"""Create a vision encoder from config.
|
||||
|
||||
Supports any timm model with "clip" or "dinov3" in the backbone name.
|
||||
The encoder type is automatically detected based on the backbone name.
|
||||
"""
|
||||
backbone_name = config.backbone.lower()
|
||||
|
||||
# Check if it's a CLIP model
|
||||
@@ -136,7 +141,7 @@ def create_vision_encoder(config) -> BaseVisionEncoder:
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported vision backbone: {config.backbone}. "
|
||||
f"Currently supported: DinoV3 models and CLIP models"
|
||||
f"Currently supported: any timm model with 'dinov3' or 'clip' in the name"
|
||||
)
|
||||
|
||||
|
||||
@@ -148,26 +153,19 @@ VISION_ENCODER_REGISTRY: dict[str, type] = {
|
||||
|
||||
|
||||
def register_vision_encoder(name: str, encoder_class: type):
|
||||
"""Register a new vision encoder type.
|
||||
|
||||
Args:
|
||||
name: Identifier for the encoder type
|
||||
encoder_class: Class implementing BaseVisionEncoder interface
|
||||
"""
|
||||
"""Register a new vision encoder type."""
|
||||
VISION_ENCODER_REGISTRY[name] = encoder_class
|
||||
|
||||
|
||||
def get_registered_encoders() -> dict[str, type]:
|
||||
"""Get all registered vision encoder types.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping encoder names to classes
|
||||
"""
|
||||
"""Get all registered vision encoder types."""
|
||||
return VISION_ENCODER_REGISTRY.copy()
|
||||
|
||||
|
||||
class CLIPTextEncoder(nn.Module):
|
||||
"""CLIP text encoder with frozen weights and learnable projection."""
|
||||
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
|
||||
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user