remove DinoV3 vision encoder and simplify text and vision encoders by removing the inheritance structure

Bryson Jones
2025-12-10 11:09:37 -08:00
parent 55e19ff9a7
commit adabb37af6
4 changed files with 42 additions and 186 deletions
@@ -122,7 +122,7 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "timm>=1.0.20"]
multi_task_dit = ["lerobot[transformers-dep]"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
@@ -196,16 +196,27 @@ class TransformerConfig:
@dataclass
class VisionEncoderConfig(draccus.ChoiceRegistry):
"""Base configuration for vision encoders.
class VisionEncoderConfig:
"""Configuration for CLIP vision encoder.
Uses CLIPVisionModel from the transformers library.
CLS token usage is handled automatically.
CLIP's internal preprocessing (resize to 224x224) can be overridden
by setting resize_shape and crop_shape.
All image preprocessing is centralized here:
1. Resize (optional) - resize images to target resolution
2. Crop (optional) - crop after resize, must be smaller than resize_shape
3. Random crop - whether to use random cropping during training
Any CLIP model from transformers can be used. Examples:
- openai/clip-vit-base-patch16 (default, 768 dims)
- openai/clip-vit-large-patch14 (1024 dims)
- laion/CLIP-ViT-B-32-laion2B-s34B-b79K (alternative CLIP model)
"""
use_separate_encoder_per_camera: bool = False # Common parameters across all vision encoders
model_name: str = "openai/clip-vit-base-patch16"
use_separate_encoder_per_camera: bool = False
# Learning rate multiplier for vision encoder parameters
# Vision encoder learning rate = optimizer_lr * lr_multiplier
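A minimal sketch of overriding the flattened config (field names as shown in this diff; the (height, width) tuple types for resize_shape and crop_shape are an assumption):

cfg = VisionEncoderConfig(
    model_name="openai/clip-vit-large-patch14",  # any transformers CLIP model
    resize_shape=(240, 320),   # step 1: resize (assumed (H, W) tuple)
    crop_shape=(224, 224),     # step 2: crop, must fit inside resize_shape
    crop_is_random=True,       # step 3: random crop during training
)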
@@ -217,6 +228,12 @@ class VisionEncoderConfig(draccus.ChoiceRegistry):
crop_is_random: bool = True
def __post_init__(self):
# Validate that model name contains "clip" to ensure correct encoder type
if "clip" not in self.model_name.lower():
raise ValueError(
f"model_name must be a CLIP model from transformers (contain 'clip'), got '{self.model_name}'"
)
if (
self.resize_shape
and self.crop_shape
@@ -228,70 +245,9 @@ class VisionEncoderConfig(draccus.ChoiceRegistry):
)
@VisionEncoderConfig.register_subclass("dinov3")
@dataclass
class DinoV3EncoderConfig(VisionEncoderConfig):
"""DinoV3 vision encoder configuration.
DinoV3 is a self-supervised Vision Transformer trained by Meta.
CLS token usage and spatial feature extraction are handled automatically.
Any timm model with "dinov3" in the name can be used. Examples:
- vit_base_patch16_dinov3.lvd1689m (768 dims)
- vit_large_patch14_dinov3.lvd142m (1024 dims)
"""
backbone: str = "vit_base_patch16_dinov3.lvd1689m"
def __post_init__(self):
super().__post_init__()
# Validate that backbone name contains "dinov3" to ensure correct encoder type
if "dinov3" not in self.backbone.lower():
raise ValueError(f"backbone must be a DinoV3 model (contain 'dinov3'), got '{self.backbone}'")
@VisionEncoderConfig.register_subclass("clip")
@dataclass
class CLIPVisionEncoderConfig(VisionEncoderConfig):
"""CLIP vision encoder configuration.
CLIP is a vision-language model trained by OpenAI.
CLS token usage is handled automatically.
CLIP's internal preprocessing (resize to 224x224) can be overridden
by setting resize_shape and crop_shape.
Any timm model with "clip" in the name can be used. Examples:
- vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224)
- vit_large_patch14_clip_224.openai (1024 dims)
"""
backbone: str = "vit_base_patch16_clip_224.openai"
def __post_init__(self):
super().__post_init__()
# Validate that backbone name contains "clip" to ensure correct encoder type
if "clip" not in self.backbone.lower():
raise ValueError(f"backbone must be a CLIP model (contain 'clip'), got '{self.backbone}'")
@dataclass
class TextEncoderConfig(draccus.ChoiceRegistry):
"""Base configuration for text encoders.
If a text encoder is set in ObservationEncoderConfig, text conditioning
is automatically enabled.
"""
pass
def __post_init__(self):
pass
@TextEncoderConfig.register_subclass("clip")
@dataclass
class CLIPTextEncoderConfig(TextEncoderConfig):
"""CLIP text encoder for task conditioning.
class TextEncoderConfig:
"""Configuration for CLIP text encoder.
Uses CLIP's text encoder to embed task descriptions, which are then
used to condition the policy. The text embeddings are processed by
@@ -306,7 +262,6 @@ class CLIPTextEncoderConfig(TextEncoderConfig):
model: str = "openai/clip-vit-base-patch16"
def __post_init__(self):
super().__post_init__()
# Validate that model name contains "clip" to ensure correct encoder type
if "clip" not in self.model.lower():
raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
@@ -317,11 +272,11 @@ class ObservationEncoderConfig:
"""Top-level configuration for observation encoding.
This config combines:
- Vision encoding (required): DinoV3 or CLIP vision encoder
- Vision encoding (required): CLIP vision encoder from transformers
"""
vision: VisionEncoderConfig = field(default_factory=CLIPVisionEncoderConfig)
text: TextEncoderConfig = field(default_factory=CLIPTextEncoderConfig)
vision: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
@PreTrainedConfig.register_subclass("multi_task_dit")
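Because both flattened configs validate in __post_init__, a wrong model name now fails at construction time rather than at weight-loading time. A sketch (the DINOv2 id below is only an illustrative non-CLIP name):

obs = ObservationEncoderConfig()  # defaults to CLIP vision + CLIP text
VisionEncoderConfig(model_name="facebook/dinov2-base")  # raises ValueError: not a CLIP model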
@@ -19,99 +19,44 @@
Handles vision encoding, text encoding, robot state, and environment state.
"""
from abc import ABC, abstractmethod
import einops
import timm
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class BaseVisionEncoder(ABC):
"""Abstract base class for vision encoders."""
@abstractmethod
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to feature maps."""
pass
@abstractmethod
def get_output_shape(self) -> tuple:
"""Get the output shape (C', H', W')."""
pass
class DinoV3Encoder(nn.Module, BaseVisionEncoder):
"""DinoV3 vision encoder using the CLS token for global image representation."""
def __init__(self, config):
super().__init__()
self.config = config
self.model_name = config.backbone
# Create the timm model
self.model = timm.create_model(
self.model_name,
pretrained=True,
num_classes=0,
)
self.num_non_spatial_tokens = 5 # 1 CLS + 4 register
self.embed_dim = self.model.embed_dim
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to feature maps."""
# Extract all features
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
# Use only the CLS token (first token)
cls_token = features[:, 0] # (B, embed_dim)
b, embed_dim = cls_token.shape
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
cls_features = cls_token.reshape(b, embed_dim, 1, 1)
return cls_features
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPEncoder(nn.Module, BaseVisionEncoder):
class CLIPVisionEncoder(nn.Module):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, config):
def __init__(self, model_name: str):
super().__init__()
self.config = config
self.model_name = config.backbone
self.model_name = model_name
# Create the timm model
self.model = timm.create_model(
self.model_name,
pretrained=True,
num_classes=0, # Remove classification head, we want features
)
# Load CLIP vision model from transformers
self.model = CLIPVisionModel.from_pretrained(self.model_name)
# CLIP models have 1 CLS token (no register tokens like DinoV3)
# CLIP models have 1 CLS token
self.num_non_spatial_tokens = 1
# Get embed_dim from model config
self.embed_dim = self.model.embed_dim
self.embed_dim = self.model.config.hidden_size
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token.
Preprocessing (resize, crop) is handled by ObservationEncoder.
"""
# Extract all features
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
# Extract features using CLIPVisionModel
# Input: (B, C, H, W) - already preprocessed
outputs = self.model(pixel_values=x, output_hidden_states=False)
# Use only the CLS token (first token)
cls_token = features[:, 0] # (B, embed_dim)
# Extract CLS token from last_hidden_state (first token)
# last_hidden_state shape: (B, sequence_length, hidden_size)
cls_token = outputs.last_hidden_state[:, 0] # (B, embed_dim)
b, embed_dim = cls_token.shape
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
@@ -122,46 +67,6 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder):
return (self.embed_dim, 1, 1)
def create_vision_encoder(config) -> BaseVisionEncoder:
"""Create a vision encoder from config.
Supports any timm model with "clip" or "dinov3" in the backbone name.
The encoder type is automatically detected based on the backbone name.
"""
backbone_name = config.backbone.lower()
# Check if it's a CLIP model
if "clip" in backbone_name:
return CLIPEncoder(config)
# Check if it's a DinoV3 model
elif "dinov3" in backbone_name:
return DinoV3Encoder(config)
else:
raise ValueError(
f"Unsupported vision backbone: {config.backbone}. "
f"Currently supported: any timm model with 'dinov3' or 'clip' in the name"
)
# Registry for easy extension
VISION_ENCODER_REGISTRY: dict[str, type] = {
"dinov3": DinoV3Encoder,
"clip": CLIPEncoder,
}
def register_vision_encoder(name: str, encoder_class: type):
"""Register a new vision encoder type."""
VISION_ENCODER_REGISTRY[name] = encoder_class
def get_registered_encoders() -> dict[str, type]:
"""Get all registered vision encoder types."""
return VISION_ENCODER_REGISTRY.copy()
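With the factory and registry gone, callers construct the encoder directly. A sketch of the call path and output shape, assuming the default openai/clip-vit-base-patch16 checkpoint (hidden size 768):

import torch

encoder = CLIPVisionEncoder(model_name="openai/clip-vit-base-patch16")
images = torch.randn(4, 3, 224, 224)  # already resized/cropped upstream
feats = encoder(images)               # CLS token reshaped to (4, 768, 1, 1)
assert encoder.get_output_shape() == (768, 1, 1)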
class CLIPTextEncoder(nn.Module):
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
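The docstring above describes the pattern: frozen CLIP text weights plus a trainable projection. A self-contained sketch of that pattern (class and argument names are illustrative, not the repo's exact implementation):

import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer

class FrozenTextEncoder(nn.Module):
    def __init__(self, model_name: str = "openai/clip-vit-base-patch16", out_dim: int = 512):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.model = CLIPTextModel.from_pretrained(model_name)
        self.model.requires_grad_(False)  # encoder weights stay frozen
        self.proj = nn.Linear(self.model.config.hidden_size, out_dim)  # learnable projection

    def forward(self, texts: list[str]):
        tokens = self.tokenizer(texts, padding=True, return_tensors="pt")
        pooled = self.model(**tokens).pooler_output  # (B, hidden_size)
        return self.proj(pooled)                     # (B, out_dim)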
@@ -231,11 +136,11 @@ class ObservationEncoder(nn.Module):
if vision_config.use_separate_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[create_vision_encoder(vision_config) for _ in self.camera_names]
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = create_vision_encoder(vision_config)
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
self.vision_encoders = None
else:
self.vision_encoder = None
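With use_separate_encoder_per_camera, each camera stream gets its own CLIP weights; otherwise one shared encoder serves every camera. A sketch of how the two branches would be consumed downstream (loop and variable names illustrative):

if vision_encoders is not None:  # one encoder per camera
    feats = [enc(images[name]) for enc, name in zip(vision_encoders, camera_names)]
else:  # single shared encoder
    feats = [vision_encoder(images[name]) for name in camera_names]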
@@ -290,7 +195,6 @@ class ObservationEncoder(nn.Module):
self.do_crop = False
def _setup_vector_output(self):
"""Setup for vector output."""
total_dim = 0
# Vision features - get CLS token feature dimension
@@ -384,11 +288,8 @@ class ObservationEncoder(nn.Module):
text_features = self.text_encoder(batch["task"]) # (B, text_dim)
# Expand across temporal dimension to match other features
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) # (B, T, text_dim)
print("Text features shape after unsqueeze and expand:", text_features.shape)
conditioning_feats.append(text_features)
for vec in conditioning_feats:
print(f"Conditioning feature shape: {vec.shape}")
combined_features = torch.cat(conditioning_feats, dim=-1) # (B, n_obs_steps, total_feature_dim)
return combined_features.flatten(start_dim=1) # (B, n_obs_steps * total_feature_dim)
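A shape walk-through of the final concatenation, using illustrative dimensions (two cameras with 768-dim CLS features, a 14-dim robot state, a 512-dim projected text embedding):

import torch

B, n_obs_steps = 8, 2
vis = torch.randn(B, n_obs_steps, 768 * 2)  # CLS features from two cameras
state = torch.randn(B, n_obs_steps, 14)     # proprioceptive state (illustrative dim)
text = torch.randn(B, n_obs_steps, 512)     # projected task embedding
combined = torch.cat([vis, state, text], dim=-1)  # (8, 2, 2062)
flat = combined.flatten(start_dim=1)              # (8, 4124)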