From adabb37af6d73d49106f578a3641b6aa383a60db Mon Sep 17 00:00:00 2001
From: Bryson Jones
Date: Wed, 10 Dec 2025 11:09:37 -0800
Subject: [PATCH] Remove the DINOv3 vision encoder and simplify the text and
 vision encoders by removing the inheritance structure

---
 pyproject.toml | 2 +-
 .../configuration_multi_task_dit.py | 95 ++++---------
 .../modules/observation_encoder.py | 131 +++---------------
 .../test_multi_task_dit.py} | 0
 4 files changed, 42 insertions(+), 186 deletions(-)
 rename tests/policies/{test_multi_task_dit_policy.py => multi_task_dit/test_multi_task_dit.py} (100%)

diff --git a/pyproject.toml b/pyproject.toml
index 0413f6f2b..2905d8931 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -122,7 +122,7 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
 # Policies
 pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
 smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
-multi_task_dit = ["lerobot[transformers-dep]", "timm>=1.0.20"]
+multi_task_dit = ["lerobot[transformers-dep]"]
 groot = [
     "lerobot[transformers-dep]",
     "peft>=0.13.0,<1.0.0",
diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
index 09d16bbed..2089b1372 100644
--- a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
+++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
@@ -196,16 +196,27 @@ class TransformerConfig:
 
 @dataclass
-class VisionEncoderConfig(draccus.ChoiceRegistry):
-    """Base configuration for vision encoders.
+class VisionEncoderConfig:
+    """Configuration for CLIP vision encoder.
+
+    Uses CLIPVisionModel from the transformers library.
+    CLS token usage is handled automatically.
+    CLIP's internal preprocessing (resize to 224x224) can be overridden
+    by setting resize_shape and crop_shape.
 
     All image preprocessing is centralized here:
     1. Resize (optional) - resize images to target resolution
    2. Crop (optional) - crop after resize, must be smaller than resize_shape
    3. Random crop - whether to use random cropping during training
+
+    Any CLIP model from transformers can be used. Examples:
+    - openai/clip-vit-base-patch16 (default, 768 dims)
+    - openai/clip-vit-large-patch14 (1024 dims)
+    - openai/clip-vit-base-patch32 (768 dims, patch size 32)
     """
 
-    use_separate_encoder_per_camera: bool = False
     # Common parameters across all vision encoders
+    model_name: str = "openai/clip-vit-base-patch16"
+    use_separate_encoder_per_camera: bool = False
 
     # Learning rate multiplier for vision encoder parameters
     # Vision encoder learning rate = optimizer_lr * lr_multiplier
@@ -217,6 +228,12 @@ class VisionEncoderConfig(draccus.ChoiceRegistry):
     crop_is_random: bool = True
 
     def __post_init__(self):
+        # Validate that model name contains "clip" to ensure correct encoder type
+        if "clip" not in self.model_name.lower():
+            raise ValueError(
+                f"model_name must be a CLIP model from transformers (contain 'clip'), got '{self.model_name}'"
+            )
+
         if (
             self.resize_shape
             and self.crop_shape
@@ -228,70 +245,9 @@ class VisionEncoderConfig(draccus.ChoiceRegistry):
             )
 
 
-@VisionEncoderConfig.register_subclass("dinov3")
 @dataclass
-class DinoV3EncoderConfig(VisionEncoderConfig):
-    """DinoV3 vision encoder configuration.
-
-    DinoV3 is a self-supervised Vision Transformer trained by Meta.
-    CLS token usage and spatial feature extraction are handled automatically.
-
-    Any timm model with "dinov3" in the name can be used. Examples:
-    - vit_base_patch16_dinov3.lvd1689m (768 dims)
-    - vit_large_patch14_dinov3.lvd142m (1024 dims)
-    """
-
-    backbone: str = "vit_base_patch16_dinov3.lvd1689m"
-
-    def __post_init__(self):
-        super().__post_init__()
-        # Validate that backbone name contains "dinov3" to ensure correct encoder type
-        if "dinov3" not in self.backbone.lower():
-            raise ValueError(f"backbone must be a DinoV3 model (contain 'dinov3'), got '{self.backbone}'")
-
-
-@VisionEncoderConfig.register_subclass("clip")
-@dataclass
-class CLIPVisionEncoderConfig(VisionEncoderConfig):
-    """CLIP vision encoder configuration.
-
-    CLIP is a vision-language model trained by OpenAI.
-    CLS token usage is handled automatically.
-    CLIP's internal preprocessing (resize to 224x224) can be overridden
-    by setting resize_shape and crop_shape.
-
-    Any timm model with "clip" in the name can be used. Examples:
-    - vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224)
-    - vit_large_patch14_clip_224.openai (1024 dims)
-    """
-
-    backbone: str = "vit_base_patch16_clip_224.openai"
-
-    def __post_init__(self):
-        super().__post_init__()
-        # Validate that backbone name contains "clip" to ensure correct encoder type
-        if "clip" not in self.backbone.lower():
-            raise ValueError(f"backbone must be a CLIP model (contain 'clip'), got '{self.backbone}'")
-
-
-@dataclass
-class TextEncoderConfig(draccus.ChoiceRegistry):
-    """Base configuration for text encoders.
-
-    If a text encoder is set in ObservationEncoderConfig, text conditioning
-    is automatically enabled.
-    """
-
-    pass
-
-    def __post_init__(self):
-        pass
-
-
-@TextEncoderConfig.register_subclass("clip")
-@dataclass
-class CLIPTextEncoderConfig(TextEncoderConfig):
-    """CLIP text encoder for task conditioning.
+class TextEncoderConfig:
+    """Configuration for CLIP text encoder.
 
     Uses CLIP's text encoder to embed task descriptions, which are then
     used to condition the policy. The text embeddings are processed by
@@ -306,7 +262,6 @@ class CLIPTextEncoderConfig(TextEncoderConfig):
     model: str = "openai/clip-vit-base-patch16"
 
     def __post_init__(self):
-        super().__post_init__()
         # Validate that model name contains "clip" to ensure correct encoder type
         if "clip" not in self.model.lower():
             raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
@@ -317,11 +272,11 @@ class ObservationEncoderConfig:
     """Top-level configuration for observation encoding.
 
     This config combines:
-    - Vision encoding (required): DinoV3 or CLIP vision encoder
+    - Vision encoding (required): CLIP vision encoder from transformers
     """
 
-    vision: VisionEncoderConfig = field(default_factory=CLIPVisionEncoderConfig)
-    text: TextEncoderConfig = field(default_factory=CLIPTextEncoderConfig)
+    vision: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
+    text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
 
 
 @PreTrainedConfig.register_subclass("multi_task_dit")
diff --git a/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py b/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py
index 1dd7ec43d..794ab71f1 100644
--- a/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py
+++ b/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py
@@ -19,99 +19,44 @@
 Handles vision encoding, text encoding, robot state, and environment state.
""" -from abc import ABC, abstractmethod - import einops -import timm import torch import torch.nn as nn import torchvision from torch import Tensor -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE -class BaseVisionEncoder(ABC): - """Abstract base class for vision encoders.""" - - @abstractmethod - def forward(self, x: Tensor) -> Tensor: - """Encode RGB image to feature maps.""" - pass - - @abstractmethod - def get_output_shape(self) -> tuple: - """Get the output shape (C', H', W').""" - pass - - -class DinoV3Encoder(nn.Module, BaseVisionEncoder): - """DinoV3 vision encoder using the CLS token for global image representation.""" - - def __init__(self, config): - super().__init__() - self.config = config - self.model_name = config.backbone - - # Create the timm model - self.model = timm.create_model( - self.model_name, - pretrained=True, - num_classes=0, - ) - - self.num_non_spatial_tokens = 5 # 1 CLS + 4 register - self.embed_dim = self.model.embed_dim - - def forward(self, x: Tensor) -> Tensor: - """Encode RGB image to feature maps.""" - # Extract all features - features = self.model.forward_features(x) # (B, total_tokens, embed_dim) - - # Use only the CLS token (first token) - cls_token = features[:, 0] # (B, embed_dim) - b, embed_dim = cls_token.shape - - # Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility - cls_features = cls_token.reshape(b, embed_dim, 1, 1) - return cls_features - - def get_output_shape(self) -> tuple: - return (self.embed_dim, 1, 1) - - -class CLIPEncoder(nn.Module, BaseVisionEncoder): +class CLIPVisionEncoder(nn.Module): """CLIP vision encoder using the CLS token for global image representation.""" - def __init__(self, config): + def __init__(self, model_name: str): super().__init__() - self.config = config - self.model_name = config.backbone + self.model_name = model_name - # Create the timm model - self.model = timm.create_model( - self.model_name, - pretrained=True, - num_classes=0, # Remove classification head, we want features - ) + # Load CLIP vision model from transformers + self.model = CLIPVisionModel.from_pretrained(self.model_name) - # CLIP models have 1 CLS token (no register tokens like DinoV3) + # CLIP models have 1 CLS token self.num_non_spatial_tokens = 1 # Get embed_dim from model config - self.embed_dim = self.model.embed_dim + self.embed_dim = self.model.config.hidden_size def forward(self, x: Tensor) -> Tensor: """Encode RGB image to CLS token. Preprocessing (resize, crop) is handled by ObservationEncoder """ - # Extract all features - features = self.model.forward_features(x) # (B, total_tokens, embed_dim) + # Extract features using CLIPVisionModel + # Input: (B, C, H, W) - already preprocessed + outputs = self.model(pixel_values=x, output_hidden_states=False) - # Use only the CLS token (first token) - cls_token = features[:, 0] # (B, embed_dim) + # Extract CLS token from last_hidden_state (first token) + # last_hidden_state shape: (B, sequence_length, hidden_size) + cls_token = outputs.last_hidden_state[:, 0] # (B, embed_dim) b, embed_dim = cls_token.shape # Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility @@ -122,46 +67,6 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder): return (self.embed_dim, 1, 1) -def create_vision_encoder(config) -> BaseVisionEncoder: - """Create a vision encoder from config. 
-
-    Supports any timm model with "clip" or "dinov3" in the backbone name.
-    The encoder type is automatically detected based on the backbone name.
-    """
-    backbone_name = config.backbone.lower()
-
-    # Check if it's a CLIP model
-    if "clip" in backbone_name:
-        return CLIPEncoder(config)
-
-    # Check if it's a DinoV3 model
-    elif "dinov3" in backbone_name:
-        return DinoV3Encoder(config)
-
-    else:
-        raise ValueError(
-            f"Unsupported vision backbone: {config.backbone}. "
-            f"Currently supported: any timm model with 'dinov3' or 'clip' in the name"
-        )
-
-
-# Registry for easy extension
-VISION_ENCODER_REGISTRY: dict[str, type] = {
-    "dinov3": DinoV3Encoder,
-    "clip": CLIPEncoder,
-}
-
-
-def register_vision_encoder(name: str, encoder_class: type):
-    """Register a new vision encoder type."""
-    VISION_ENCODER_REGISTRY[name] = encoder_class
-
-
-def get_registered_encoders() -> dict[str, type]:
-    """Get all registered vision encoder types."""
-    return VISION_ENCODER_REGISTRY.copy()
-
-
 class CLIPTextEncoder(nn.Module):
     """Supports any HuggingFace CLIP model. The encoder weights are frozen, and a learnable projection layer
     maps the CLIP embeddings to the desired dimension.
@@ -231,11 +136,11 @@ class ObservationEncoder(nn.Module):
 
             if vision_config.use_separate_encoder_per_camera:
                 self.vision_encoders = nn.ModuleList(
-                    [create_vision_encoder(vision_config) for _ in self.camera_names]
+                    [CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
                 )
                 self.vision_encoder = None
             else:
-                self.vision_encoder = create_vision_encoder(vision_config)
+                self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
                 self.vision_encoders = None
         else:
             self.vision_encoder = None
@@ -290,7 +195,6 @@ class ObservationEncoder(nn.Module):
             self.do_crop = False
 
     def _setup_vector_output(self):
-        """Setup for vector output."""
         total_dim = 0
 
         # Vision features - get CLS token feature dimension
@@ -384,11 +288,8 @@ class ObservationEncoder(nn.Module):
             text_features = self.text_encoder(batch["task"])  # (B, text_dim)
             # Expand across temporal dimension to match other features
             text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)  # (B, T, text_dim)
-            print("Text features shape after unsqueeze and expand:", text_features.shape)
             conditioning_feats.append(text_features)
 
-        for vec in conditioning_feats:
-            print(f"Conditioning feature shape: {vec.shape}")
         combined_features = torch.cat(conditioning_feats, dim=-1)  # (B, n_obs_steps, total_feature_dim)
         return combined_features.flatten(start_dim=1)  # (B, n_obs_steps * total_feature_dim)
 
diff --git a/tests/policies/test_multi_task_dit_policy.py b/tests/policies/multi_task_dit/test_multi_task_dit.py
similarity index 100%
rename from tests/policies/test_multi_task_dit_policy.py
rename to tests/policies/multi_task_dit/test_multi_task_dit.py
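
A minimal smoke test for the new CLIP-only vision path follows. This is a sketch and not part of the patch: it assumes torch and transformers are installed, and it mirrors what the new CLIPVisionEncoder.forward does (run CLIPVisionModel, take the CLS token from last_hidden_state, and reshape it to (B, C, 1, 1)).

import torch
from transformers import CLIPVisionModel

# Load the same default checkpoint as VisionEncoderConfig.model_name
model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")
model.eval()

# CLIP expects 224x224 RGB input; in the policy, resize/crop happen in ObservationEncoder
pixel_values = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

# CLS token is the first token of last_hidden_state; reshape to (B, C, 1, 1) for compatibility
cls_token = outputs.last_hidden_state[:, 0]
cls_features = cls_token.reshape(cls_token.shape[0], -1, 1, 1)
print(cls_features.shape)  # torch.Size([2, 768, 1, 1])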