mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
remove dino vision encoder and simplify text and vision encoders by removing inheritance structure
This commit is contained in:
+1
-1
@@ -122,7 +122,7 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
|||||||
# Policies
|
# Policies
|
||||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||||
multi_task_dit = ["lerobot[transformers-dep]", "timm>=1.0.20"]
|
multi_task_dit = ["lerobot[transformers-dep]"]
|
||||||
groot = [
|
groot = [
|
||||||
"lerobot[transformers-dep]",
|
"lerobot[transformers-dep]",
|
||||||
"peft>=0.13.0,<1.0.0",
|
"peft>=0.13.0,<1.0.0",
|
||||||
|
|||||||
@@ -196,16 +196,27 @@ class TransformerConfig:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VisionEncoderConfig(draccus.ChoiceRegistry):
|
class VisionEncoderConfig:
|
||||||
"""Base configuration for vision encoders.
|
"""Configuration for CLIP vision encoder.
|
||||||
|
|
||||||
|
Uses CLIPVisionModel from transformers library.
|
||||||
|
CLS token usage is handled automatically.
|
||||||
|
CLIP's internal preprocessing (resize to 224x224) can be overridden
|
||||||
|
by setting resize_shape and crop_shape.
|
||||||
|
|
||||||
All image preprocessing is centralized here:
|
All image preprocessing is centralized here:
|
||||||
1. Resize (optional) - resize images to target resolution
|
1. Resize (optional) - resize images to target resolution
|
||||||
2. Crop (optional) - crop after resize, must be smaller than resize_shape
|
2. Crop (optional) - crop after resize, must be smaller than resize_shape
|
||||||
3. Random crop - whether to use random cropping during training
|
3. Random crop - whether to use random cropping during training
|
||||||
|
|
||||||
|
Any CLIP model from transformers can be used. Examples:
|
||||||
|
- openai/clip-vit-base-patch16 (default, 768 dims)
|
||||||
|
- openai/clip-vit-large-patch14 (1024 dims)
|
||||||
|
- laion/CLIP-ViT-B-32-xlaai256 (alternative CLIP model)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
use_separate_encoder_per_camera: bool = False # Common parameters across all vision encoders
|
model_name: str = "openai/clip-vit-base-patch16"
|
||||||
|
use_separate_encoder_per_camera: bool = False
|
||||||
|
|
||||||
# Learning rate multiplier for vision encoder parameters
|
# Learning rate multiplier for vision encoder parameters
|
||||||
# Vision encoder learning rate = optimizer_lr * lr_multiplier
|
# Vision encoder learning rate = optimizer_lr * lr_multiplier
|
||||||
@@ -217,6 +228,12 @@ class VisionEncoderConfig(draccus.ChoiceRegistry):
|
|||||||
crop_is_random: bool = True
|
crop_is_random: bool = True
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# Validate that model name contains "clip" to ensure correct encoder type
|
||||||
|
if "clip" not in self.model_name.lower():
|
||||||
|
raise ValueError(
|
||||||
|
f"model_name must be a CLIP model from transformers (contain 'clip'), got '{self.model_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.resize_shape
|
self.resize_shape
|
||||||
and self.crop_shape
|
and self.crop_shape
|
||||||
@@ -228,70 +245,9 @@ class VisionEncoderConfig(draccus.ChoiceRegistry):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@VisionEncoderConfig.register_subclass("dinov3")
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DinoV3EncoderConfig(VisionEncoderConfig):
|
class TextEncoderConfig:
|
||||||
"""DinoV3 vision encoder configuration.
|
"""Configuration for CLIP text encoder.
|
||||||
|
|
||||||
DinoV3 is a self-supervised Vision Transformer trained by Meta.
|
|
||||||
CLS token usage and spatial feature extraction are handled automatically.
|
|
||||||
|
|
||||||
Any timm model with "dinov3" in the name can be used. Examples:
|
|
||||||
- vit_base_patch16_dinov3.lvd1689m (768 dims)
|
|
||||||
- vit_large_patch14_dinov3.lvd142m (1024 dims)
|
|
||||||
"""
|
|
||||||
|
|
||||||
backbone: str = "vit_base_patch16_dinov3.lvd1689m"
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__post_init__()
|
|
||||||
# Validate that backbone name contains "dinov3" to ensure correct encoder type
|
|
||||||
if "dinov3" not in self.backbone.lower():
|
|
||||||
raise ValueError(f"backbone must be a DinoV3 model (contain 'dinov3'), got '{self.backbone}'")
|
|
||||||
|
|
||||||
|
|
||||||
@VisionEncoderConfig.register_subclass("clip")
|
|
||||||
@dataclass
|
|
||||||
class CLIPVisionEncoderConfig(VisionEncoderConfig):
|
|
||||||
"""CLIP vision encoder configuration.
|
|
||||||
|
|
||||||
CLIP is a vision-language model trained by OpenAI.
|
|
||||||
CLS token usage is handled automatically.
|
|
||||||
CLIP's internal preprocessing (resize to 224x224) can be overridden
|
|
||||||
by setting resize_shape and crop_shape.
|
|
||||||
|
|
||||||
Any timm model with "clip" in the name can be used. Examples:
|
|
||||||
- vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224)
|
|
||||||
- vit_large_patch14_clip_224.openai (1024 dims)
|
|
||||||
"""
|
|
||||||
|
|
||||||
backbone: str = "vit_base_patch16_clip_224.openai"
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__post_init__()
|
|
||||||
# Validate that backbone name contains "clip" to ensure correct encoder type
|
|
||||||
if "clip" not in self.backbone.lower():
|
|
||||||
raise ValueError(f"backbone must be a CLIP model (contain 'clip'), got '{self.backbone}'")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TextEncoderConfig(draccus.ChoiceRegistry):
|
|
||||||
"""Base configuration for text encoders.
|
|
||||||
|
|
||||||
If a text encoder is set in ObservationEncoderConfig, text conditioning
|
|
||||||
is automatically enabled.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@TextEncoderConfig.register_subclass("clip")
|
|
||||||
@dataclass
|
|
||||||
class CLIPTextEncoderConfig(TextEncoderConfig):
|
|
||||||
"""CLIP text encoder for task conditioning.
|
|
||||||
|
|
||||||
Uses CLIP's text encoder to embed task descriptions, which are then
|
Uses CLIP's text encoder to embed task descriptions, which are then
|
||||||
used to condition the policy. The text embeddings are processed by
|
used to condition the policy. The text embeddings are processed by
|
||||||
@@ -306,7 +262,6 @@ class CLIPTextEncoderConfig(TextEncoderConfig):
|
|||||||
model: str = "openai/clip-vit-base-patch16"
|
model: str = "openai/clip-vit-base-patch16"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
|
||||||
# Validate that model name contains "clip" to ensure correct encoder type
|
# Validate that model name contains "clip" to ensure correct encoder type
|
||||||
if "clip" not in self.model.lower():
|
if "clip" not in self.model.lower():
|
||||||
raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
|
raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
|
||||||
@@ -317,11 +272,11 @@ class ObservationEncoderConfig:
|
|||||||
"""Top-level configuration for observation encoding.
|
"""Top-level configuration for observation encoding.
|
||||||
|
|
||||||
This config combines:
|
This config combines:
|
||||||
- Vision encoding (required): DinoV3 or CLIP vision encoder
|
- Vision encoding (required): CLIP vision encoder from transformers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vision: VisionEncoderConfig = field(default_factory=CLIPVisionEncoderConfig)
|
vision: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
|
||||||
text: TextEncoderConfig = field(default_factory=CLIPTextEncoderConfig)
|
text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("multi_task_dit")
|
@PreTrainedConfig.register_subclass("multi_task_dit")
|
||||||
|
|||||||
@@ -19,99 +19,44 @@
|
|||||||
Handles vision encoding, text encoding, robot state, and environment state.
|
Handles vision encoding, text encoding, robot state, and environment state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import timm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
||||||
|
|
||||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
class BaseVisionEncoder(ABC):
|
class CLIPVisionEncoder(nn.Module):
|
||||||
"""Abstract base class for vision encoders."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
"""Encode RGB image to feature maps."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_output_shape(self) -> tuple:
|
|
||||||
"""Get the output shape (C', H', W')."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DinoV3Encoder(nn.Module, BaseVisionEncoder):
|
|
||||||
"""DinoV3 vision encoder using the CLS token for global image representation."""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.model_name = config.backbone
|
|
||||||
|
|
||||||
# Create the timm model
|
|
||||||
self.model = timm.create_model(
|
|
||||||
self.model_name,
|
|
||||||
pretrained=True,
|
|
||||||
num_classes=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.num_non_spatial_tokens = 5 # 1 CLS + 4 register
|
|
||||||
self.embed_dim = self.model.embed_dim
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
"""Encode RGB image to feature maps."""
|
|
||||||
# Extract all features
|
|
||||||
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
|
|
||||||
|
|
||||||
# Use only the CLS token (first token)
|
|
||||||
cls_token = features[:, 0] # (B, embed_dim)
|
|
||||||
b, embed_dim = cls_token.shape
|
|
||||||
|
|
||||||
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
|
|
||||||
cls_features = cls_token.reshape(b, embed_dim, 1, 1)
|
|
||||||
return cls_features
|
|
||||||
|
|
||||||
def get_output_shape(self) -> tuple:
|
|
||||||
return (self.embed_dim, 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPEncoder(nn.Module, BaseVisionEncoder):
|
|
||||||
"""CLIP vision encoder using the CLS token for global image representation."""
|
"""CLIP vision encoder using the CLS token for global image representation."""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, model_name: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.model_name = model_name
|
||||||
self.model_name = config.backbone
|
|
||||||
|
|
||||||
# Create the timm model
|
# Load CLIP vision model from transformers
|
||||||
self.model = timm.create_model(
|
self.model = CLIPVisionModel.from_pretrained(self.model_name)
|
||||||
self.model_name,
|
|
||||||
pretrained=True,
|
|
||||||
num_classes=0, # Remove classification head, we want features
|
|
||||||
)
|
|
||||||
|
|
||||||
# CLIP models have 1 CLS token (no register tokens like DinoV3)
|
# CLIP models have 1 CLS token
|
||||||
self.num_non_spatial_tokens = 1
|
self.num_non_spatial_tokens = 1
|
||||||
|
|
||||||
# Get embed_dim from model config
|
# Get embed_dim from model config
|
||||||
self.embed_dim = self.model.embed_dim
|
self.embed_dim = self.model.config.hidden_size
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
"""Encode RGB image to CLS token.
|
"""Encode RGB image to CLS token.
|
||||||
|
|
||||||
Preprocessing (resize, crop) is handled by ObservationEncoder
|
Preprocessing (resize, crop) is handled by ObservationEncoder
|
||||||
"""
|
"""
|
||||||
# Extract all features
|
# Extract features using CLIPVisionModel
|
||||||
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
|
# Input: (B, C, H, W) - already preprocessed
|
||||||
|
outputs = self.model(pixel_values=x, output_hidden_states=False)
|
||||||
|
|
||||||
# Use only the CLS token (first token)
|
# Extract CLS token from last_hidden_state (first token)
|
||||||
cls_token = features[:, 0] # (B, embed_dim)
|
# last_hidden_state shape: (B, sequence_length, hidden_size)
|
||||||
|
cls_token = outputs.last_hidden_state[:, 0] # (B, embed_dim)
|
||||||
b, embed_dim = cls_token.shape
|
b, embed_dim = cls_token.shape
|
||||||
|
|
||||||
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
|
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
|
||||||
@@ -122,46 +67,6 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder):
|
|||||||
return (self.embed_dim, 1, 1)
|
return (self.embed_dim, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
def create_vision_encoder(config) -> BaseVisionEncoder:
|
|
||||||
"""Create a vision encoder from config.
|
|
||||||
|
|
||||||
Supports any timm model with "clip" or "dinov3" in the backbone name.
|
|
||||||
The encoder type is automatically detected based on the backbone name.
|
|
||||||
"""
|
|
||||||
backbone_name = config.backbone.lower()
|
|
||||||
|
|
||||||
# Check if it's a CLIP model
|
|
||||||
if "clip" in backbone_name:
|
|
||||||
return CLIPEncoder(config)
|
|
||||||
|
|
||||||
# Check if it's a DinoV3 model
|
|
||||||
elif "dinov3" in backbone_name:
|
|
||||||
return DinoV3Encoder(config)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported vision backbone: {config.backbone}. "
|
|
||||||
f"Currently supported: any timm model with 'dinov3' or 'clip' in the name"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Registry for easy extension
|
|
||||||
VISION_ENCODER_REGISTRY: dict[str, type] = {
|
|
||||||
"dinov3": DinoV3Encoder,
|
|
||||||
"clip": CLIPEncoder,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def register_vision_encoder(name: str, encoder_class: type):
|
|
||||||
"""Register a new vision encoder type."""
|
|
||||||
VISION_ENCODER_REGISTRY[name] = encoder_class
|
|
||||||
|
|
||||||
|
|
||||||
def get_registered_encoders() -> dict[str, type]:
|
|
||||||
"""Get all registered vision encoder types."""
|
|
||||||
return VISION_ENCODER_REGISTRY.copy()
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncoder(nn.Module):
|
class CLIPTextEncoder(nn.Module):
|
||||||
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
|
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
|
||||||
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
|
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
|
||||||
@@ -231,11 +136,11 @@ class ObservationEncoder(nn.Module):
|
|||||||
|
|
||||||
if vision_config.use_separate_encoder_per_camera:
|
if vision_config.use_separate_encoder_per_camera:
|
||||||
self.vision_encoders = nn.ModuleList(
|
self.vision_encoders = nn.ModuleList(
|
||||||
[create_vision_encoder(vision_config) for _ in self.camera_names]
|
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
|
||||||
)
|
)
|
||||||
self.vision_encoder = None
|
self.vision_encoder = None
|
||||||
else:
|
else:
|
||||||
self.vision_encoder = create_vision_encoder(vision_config)
|
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
|
||||||
self.vision_encoders = None
|
self.vision_encoders = None
|
||||||
else:
|
else:
|
||||||
self.vision_encoder = None
|
self.vision_encoder = None
|
||||||
@@ -290,7 +195,6 @@ class ObservationEncoder(nn.Module):
|
|||||||
self.do_crop = False
|
self.do_crop = False
|
||||||
|
|
||||||
def _setup_vector_output(self):
|
def _setup_vector_output(self):
|
||||||
"""Setup for vector output."""
|
|
||||||
total_dim = 0
|
total_dim = 0
|
||||||
|
|
||||||
# Vision features - get CLS token feature dimension
|
# Vision features - get CLS token feature dimension
|
||||||
@@ -384,11 +288,8 @@ class ObservationEncoder(nn.Module):
|
|||||||
text_features = self.text_encoder(batch["task"]) # (B, text_dim)
|
text_features = self.text_encoder(batch["task"]) # (B, text_dim)
|
||||||
# Expand across temporal dimension to match other features
|
# Expand across temporal dimension to match other features
|
||||||
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) # (B, T, text_dim)
|
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) # (B, T, text_dim)
|
||||||
print("Text features shape after unsqueeze and expand:", text_features.shape)
|
|
||||||
conditioning_feats.append(text_features)
|
conditioning_feats.append(text_features)
|
||||||
|
|
||||||
for vec in conditioning_feats:
|
|
||||||
print(f"Conditioning feature shape: {vec.shape}")
|
|
||||||
combined_features = torch.cat(conditioning_feats, dim=-1) # (B, n_obs_steps, total_feature_dim)
|
combined_features = torch.cat(conditioning_feats, dim=-1) # (B, n_obs_steps, total_feature_dim)
|
||||||
|
|
||||||
return combined_features.flatten(start_dim=1) # (B, n_obs_steps * total_feature_dim)
|
return combined_features.flatten(start_dim=1) # (B, n_obs_steps * total_feature_dim)
|
||||||
|
|||||||
Reference in New Issue
Block a user