Mirror of https://github.com/huggingface/lerobot.git
remove dino vision encoder and simplify text and vision encoders by removing inheritance structure
@@ -122,7 +122,7 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "timm>=1.0.20"]
multi_task_dit = ["lerobot[transformers-dep]"]
groot = [
    "lerobot[transformers-dep]",
    "peft>=0.13.0,<1.0.0",

@@ -196,16 +196,27 @@ class TransformerConfig:


@dataclass
class VisionEncoderConfig(draccus.ChoiceRegistry):
    """Base configuration for vision encoders.
class VisionEncoderConfig:
    """Configuration for CLIP vision encoder.

    Uses CLIPVisionModel from transformers library.
    CLS token usage is handled automatically.
    CLIP's internal preprocessing (resize to 224x224) can be overridden
    by setting resize_shape and crop_shape.

    All image preprocessing is centralized here:
    1. Resize (optional) - resize images to target resolution
    2. Crop (optional) - crop after resize, must be smaller than resize_shape
    3. Random crop - whether to use random cropping during training

    Any CLIP model from transformers can be used. Examples:
    - openai/clip-vit-base-patch16 (default, 768 dims)
    - openai/clip-vit-large-patch14 (1024 dims)
    - laion/CLIP-ViT-B-32-xlaai256 (alternative CLIP model)
    """

    use_separate_encoder_per_camera: bool = False  # Common parameters across all vision encoders
    model_name: str = "openai/clip-vit-base-patch16"
    use_separate_encoder_per_camera: bool = False

    # Learning rate multiplier for vision encoder parameters
    # Vision encoder learning rate = optimizer_lr * lr_multiplier
@@ -217,6 +228,12 @@ class VisionEncoderConfig(draccus.ChoiceRegistry):
    crop_is_random: bool = True

    def __post_init__(self):
        # Validate that model name contains "clip" to ensure correct encoder type
        if "clip" not in self.model_name.lower():
            raise ValueError(
                f"model_name must be a CLIP model from transformers (contain 'clip'), got '{self.model_name}'"
            )

        if (
            self.resize_shape
            and self.crop_shape
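
For orientation, constructing the simplified config might look like the sketch below. The field names (model_name, resize_shape, crop_shape, crop_is_random) come from this diff, but the full field list and the exact shape types are assumptions, so treat this as a sketch rather than the definitive API.

# Sketch only: assumes VisionEncoderConfig from this module is importable
# and that resize/crop shapes are (H, W) tuples.
cfg = VisionEncoderConfig(
    model_name="openai/clip-vit-large-patch14",  # any transformers CLIP checkpoint
    resize_shape=(256, 256),                     # optional resize applied first
    crop_shape=(224, 224),                       # must be smaller than resize_shape
    crop_is_random=True,                         # random crop during training
)

# __post_init__ rejects non-CLIP checkpoints, e.g.:
# VisionEncoderConfig(model_name="facebook/dinov2-base")  -> ValueError
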
@@ -228,70 +245,9 @@ class VisionEncoderConfig(draccus.ChoiceRegistry):
            )


@VisionEncoderConfig.register_subclass("dinov3")
@dataclass
class DinoV3EncoderConfig(VisionEncoderConfig):
    """DinoV3 vision encoder configuration.

    DinoV3 is a self-supervised Vision Transformer trained by Meta.
    CLS token usage and spatial feature extraction are handled automatically.

    Any timm model with "dinov3" in the name can be used. Examples:
    - vit_base_patch16_dinov3.lvd1689m (768 dims)
    - vit_large_patch14_dinov3.lvd142m (1024 dims)
    """

    backbone: str = "vit_base_patch16_dinov3.lvd1689m"

    def __post_init__(self):
        super().__post_init__()
        # Validate that backbone name contains "dinov3" to ensure correct encoder type
        if "dinov3" not in self.backbone.lower():
            raise ValueError(f"backbone must be a DinoV3 model (contain 'dinov3'), got '{self.backbone}'")


@VisionEncoderConfig.register_subclass("clip")
@dataclass
class CLIPVisionEncoderConfig(VisionEncoderConfig):
    """CLIP vision encoder configuration.

    CLIP is a vision-language model trained by OpenAI.
    CLS token usage is handled automatically.
    CLIP's internal preprocessing (resize to 224x224) can be overridden
    by setting resize_shape and crop_shape.

    Any timm model with "clip" in the name can be used. Examples:
    - vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224)
    - vit_large_patch14_clip_224.openai (1024 dims)
    """

    backbone: str = "vit_base_patch16_clip_224.openai"

    def __post_init__(self):
        super().__post_init__()
        # Validate that backbone name contains "clip" to ensure correct encoder type
        if "clip" not in self.backbone.lower():
            raise ValueError(f"backbone must be a CLIP model (contain 'clip'), got '{self.backbone}'")


@dataclass
class TextEncoderConfig(draccus.ChoiceRegistry):
    """Base configuration for text encoders.

    If a text encoder is set in ObservationEncoderConfig, text conditioning
    is automatically enabled.
    """

    pass

    def __post_init__(self):
        pass


@TextEncoderConfig.register_subclass("clip")
@dataclass
class CLIPTextEncoderConfig(TextEncoderConfig):
    """CLIP text encoder for task conditioning.
class TextEncoderConfig:
    """Configuration for CLIP text encoder.

    Uses CLIP's text encoder to embed task descriptions, which are then
    used to condition the policy. The text embeddings are processed by
@@ -306,7 +262,6 @@ class CLIPTextEncoderConfig(TextEncoderConfig):
    model: str = "openai/clip-vit-base-patch16"

    def __post_init__(self):
        super().__post_init__()
        # Validate that model name contains "clip" to ensure correct encoder type
        if "clip" not in self.model.lower():
            raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
@@ -317,11 +272,11 @@ class ObservationEncoderConfig:
    """Top-level configuration for observation encoding.

    This config combines:
    - Vision encoding (required): DinoV3 or CLIP vision encoder
    - Vision encoding (required): CLIP vision encoder from transformers
    """

    vision: VisionEncoderConfig = field(default_factory=CLIPVisionEncoderConfig)
    text: TextEncoderConfig = field(default_factory=CLIPTextEncoderConfig)
    vision: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
    text: TextEncoderConfig = field(default_factory=TextEncoderConfig)


@PreTrainedConfig.register_subclass("multi_task_dit")

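
With the subclasses gone, the top-level config defaults to the two simplified classes directly. A rough sketch of overriding the defaults follows; it assumes these classes are importable from the policy's configuration module and that the text encoder keeps the `model` field shown earlier in this diff.

# Sketch: other ObservationEncoderConfig fields (state, env state, text projection dim)
# are not shown in this hunk and are omitted here.
obs_cfg = ObservationEncoderConfig(
    vision=VisionEncoderConfig(model_name="openai/clip-vit-base-patch16"),
    text=TextEncoderConfig(model="openai/clip-vit-base-patch16"),
)
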
@@ -19,99 +19,44 @@
Handles vision encoding, text encoding, robot state, and environment state.
"""

from abc import ABC, abstractmethod

import einops
import timm
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel

from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE


class BaseVisionEncoder(ABC):
    """Abstract base class for vision encoders."""

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Encode RGB image to feature maps."""
        pass

    @abstractmethod
    def get_output_shape(self) -> tuple:
        """Get the output shape (C', H', W')."""
        pass


class DinoV3Encoder(nn.Module, BaseVisionEncoder):
    """DinoV3 vision encoder using the CLS token for global image representation."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model_name = config.backbone

        # Create the timm model
        self.model = timm.create_model(
            self.model_name,
            pretrained=True,
            num_classes=0,
        )

        self.num_non_spatial_tokens = 5  # 1 CLS + 4 register
        self.embed_dim = self.model.embed_dim

    def forward(self, x: Tensor) -> Tensor:
        """Encode RGB image to feature maps."""
        # Extract all features
        features = self.model.forward_features(x)  # (B, total_tokens, embed_dim)

        # Use only the CLS token (first token)
        cls_token = features[:, 0]  # (B, embed_dim)
        b, embed_dim = cls_token.shape

        # Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
        cls_features = cls_token.reshape(b, embed_dim, 1, 1)
        return cls_features

    def get_output_shape(self) -> tuple:
        return (self.embed_dim, 1, 1)


class CLIPEncoder(nn.Module, BaseVisionEncoder):
class CLIPVisionEncoder(nn.Module):
    """CLIP vision encoder using the CLS token for global image representation."""

    def __init__(self, config):
    def __init__(self, model_name: str):
        super().__init__()
        self.config = config
        self.model_name = config.backbone
        self.model_name = model_name

        # Create the timm model
        self.model = timm.create_model(
            self.model_name,
            pretrained=True,
            num_classes=0,  # Remove classification head, we want features
        )
        # Load CLIP vision model from transformers
        self.model = CLIPVisionModel.from_pretrained(self.model_name)

        # CLIP models have 1 CLS token (no register tokens like DinoV3)
        # CLIP models have 1 CLS token
        self.num_non_spatial_tokens = 1

        # Get embed_dim from model config
        self.embed_dim = self.model.embed_dim
        self.embed_dim = self.model.config.hidden_size

    def forward(self, x: Tensor) -> Tensor:
        """Encode RGB image to CLS token.

        Preprocessing (resize, crop) is handled by ObservationEncoder
        """
        # Extract all features
        features = self.model.forward_features(x)  # (B, total_tokens, embed_dim)
        # Extract features using CLIPVisionModel
        # Input: (B, C, H, W) - already preprocessed
        outputs = self.model(pixel_values=x, output_hidden_states=False)

        # Use only the CLS token (first token)
        cls_token = features[:, 0]  # (B, embed_dim)
        # Extract CLS token from last_hidden_state (first token)
        # last_hidden_state shape: (B, sequence_length, hidden_size)
        cls_token = outputs.last_hidden_state[:, 0]  # (B, embed_dim)
        b, embed_dim = cls_token.shape

        # Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
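
The rewritten forward pass relies only on standard CLIPVisionModel behaviour from transformers, so the CLS-token extraction it performs can be reproduced standalone. The checkpoint name and batch size below are arbitrary examples.

import torch
from transformers import CLIPVisionModel

model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")
model.eval()

pixels = torch.randn(2, 3, 224, 224)      # (B, C, H, W), already resized/cropped upstream
with torch.no_grad():
    out = model(pixel_values=pixels)

cls = out.last_hidden_state[:, 0]          # (2, 768): one CLS token per image
features = cls.reshape(2, -1, 1, 1)        # (B, embed_dim, 1, 1), as in the encoder above
print(features.shape, model.config.hidden_size)
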
@@ -122,46 +67,6 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder):
        return (self.embed_dim, 1, 1)


def create_vision_encoder(config) -> BaseVisionEncoder:
    """Create a vision encoder from config.

    Supports any timm model with "clip" or "dinov3" in the backbone name.
    The encoder type is automatically detected based on the backbone name.
    """
    backbone_name = config.backbone.lower()

    # Check if it's a CLIP model
    if "clip" in backbone_name:
        return CLIPEncoder(config)

    # Check if it's a DinoV3 model
    elif "dinov3" in backbone_name:
        return DinoV3Encoder(config)

    else:
        raise ValueError(
            f"Unsupported vision backbone: {config.backbone}. "
            f"Currently supported: any timm model with 'dinov3' or 'clip' in the name"
        )


# Registry for easy extension
VISION_ENCODER_REGISTRY: dict[str, type] = {
    "dinov3": DinoV3Encoder,
    "clip": CLIPEncoder,
}


def register_vision_encoder(name: str, encoder_class: type):
    """Register a new vision encoder type."""
    VISION_ENCODER_REGISTRY[name] = encoder_class


def get_registered_encoders() -> dict[str, type]:
    """Get all registered vision encoder types."""
    return VISION_ENCODER_REGISTRY.copy()


class CLIPTextEncoder(nn.Module):
    """Supports any HuggingFace CLIP model. The encoder weights are frozen,
    and a learnable projection layer maps the CLIP embeddings to the desired dimension.
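
The CLIPTextEncoder body itself is unchanged and mostly outside this hunk. As a hedged illustration of the pattern its docstring describes (frozen CLIP text weights plus a learnable projection), a hypothetical stand-in could look like the following; it is not the repository's implementation, and the output dimension is chosen arbitrarily.

import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class FrozenCLIPTextEmbedder(nn.Module):
    """Hypothetical stand-in, not the repository's CLIPTextEncoder."""

    def __init__(self, model_name: str = "openai/clip-vit-base-patch16", out_dim: int = 256):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.encoder = CLIPTextModel.from_pretrained(model_name)
        for p in self.encoder.parameters():
            p.requires_grad = False  # CLIP weights stay frozen
        self.proj = nn.Linear(self.encoder.config.hidden_size, out_dim)  # learnable projection

    def forward(self, texts: list[str]):
        tokens = self.tokenizer(texts, padding=True, return_tensors="pt")
        pooled = self.encoder(**tokens).pooler_output  # (B, hidden_size)
        return self.proj(pooled)                       # (B, out_dim)
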
@@ -231,11 +136,11 @@ class ObservationEncoder(nn.Module):

            if vision_config.use_separate_encoder_per_camera:
                self.vision_encoders = nn.ModuleList(
                    [create_vision_encoder(vision_config) for _ in self.camera_names]
                    [CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
                )
                self.vision_encoder = None
            else:
                self.vision_encoder = create_vision_encoder(vision_config)
                self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
                self.vision_encoders = None
        else:
            self.vision_encoder = None
@@ -290,7 +195,6 @@ class ObservationEncoder(nn.Module):
            self.do_crop = False

    def _setup_vector_output(self):
        """Setup for vector output."""
        total_dim = 0

        # Vision features - get CLS token feature dimension
@@ -384,11 +288,8 @@ class ObservationEncoder(nn.Module):
            text_features = self.text_encoder(batch["task"])  # (B, text_dim)
            # Expand across temporal dimension to match other features
            text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)  # (B, T, text_dim)
            print("Text features shape after unsqueeze and expand:", text_features.shape)
            conditioning_feats.append(text_features)

        for vec in conditioning_feats:
            print(f"Conditioning feature shape: {vec.shape}")
        combined_features = torch.cat(conditioning_feats, dim=-1)  # (B, n_obs_steps, total_feature_dim)

        return combined_features.flatten(start_dim=1)  # (B, n_obs_steps * total_feature_dim)
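
This hunk drops the debug prints around the conditioning features; the shape bookkeeping they logged is easy to reproduce standalone. All dimensions below are made-up examples, not values from the policy.

import torch

B, n_obs_steps, state_dim, vision_dim, text_dim = 4, 2, 7, 768, 512

state_feats = torch.randn(B, n_obs_steps, state_dim)
vision_feats = torch.randn(B, n_obs_steps, vision_dim)
text_feats = torch.randn(B, text_dim)

# The task embedding is computed once, then broadcast over the observation steps.
text_feats = text_feats.unsqueeze(1).expand(-1, n_obs_steps, -1)        # (B, T, text_dim)

combined = torch.cat([state_feats, vision_feats, text_feats], dim=-1)   # (B, T, 7 + 768 + 512)
flat = combined.flatten(start_dim=1)                                    # (B, T * 1287)
print(flat.shape)  # torch.Size([4, 2574])
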