mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
@@ -15,8 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal
|
||||||
import draccus
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
@@ -24,295 +23,84 @@ from lerobot.optim.optimizers import AdamConfig
|
|||||||
from lerobot.optim.schedulers import DiffuserSchedulerConfig
|
from lerobot.optim.schedulers import DiffuserSchedulerConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ObjectiveConfig(draccus.ChoiceRegistry):
|
|
||||||
"""Base configuration for model objectives (diffusion, flow matching, etc.)."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@ObjectiveConfig.register_subclass("diffusion")
|
|
||||||
@dataclass
|
|
||||||
class DiffusionConfig(ObjectiveConfig):
|
|
||||||
"""Configuration for standard diffusion model training and inference.
|
|
||||||
|
|
||||||
These parameters control the noise scheduling and denoising process for
|
|
||||||
standard DDPM/DDIM diffusion models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
objective_name: str = field(default="diffusion", init=False)
|
|
||||||
|
|
||||||
# Noise scheduler configuration - controls diffusion process
|
|
||||||
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
|
|
||||||
num_train_timesteps: int = 100 # 100 noise levels for fine-grained control
|
|
||||||
beta_schedule: str = "squaredcos_cap_v2" # Cosine schedule prevents extreme noise
|
|
||||||
beta_start: float = 0.0001 # Small initial noise level
|
|
||||||
beta_end: float = 0.02 # Moderate final noise level
|
|
||||||
prediction_type: str = "epsilon" # Predict noise (works better than direct prediction)
|
|
||||||
clip_sample: bool = True # Prevent extreme action values
|
|
||||||
clip_sample_range: float = 1.0 # Clip to [-1, 1] range
|
|
||||||
|
|
||||||
# Inference configuration
|
|
||||||
num_inference_steps: int | None = None # Default to num_train_timesteps
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""Validate diffusion-specific parameters."""
|
|
||||||
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.prediction_type not in ["epsilon", "sample"]:
|
|
||||||
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
|
|
||||||
|
|
||||||
if self.num_train_timesteps <= 0:
|
|
||||||
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
|
|
||||||
|
|
||||||
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
|
|
||||||
raise ValueError(
|
|
||||||
"beta values must satisfy 0 <= beta_start <= beta_end <= 1, "
|
|
||||||
f"got {self.beta_start}, {self.beta_end}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TimestepSamplingConfig(draccus.ChoiceRegistry):
|
|
||||||
"""Base configuration for timestep sampling strategies during training."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@TimestepSamplingConfig.register_subclass("uniform")
|
|
||||||
@dataclass
|
|
||||||
class UniformTimestepSamplingConfig(TimestepSamplingConfig):
|
|
||||||
"""Uniform timestep sampling from [0, 1]."""
|
|
||||||
|
|
||||||
strategy_name: str = field(default="uniform", init=False)
|
|
||||||
|
|
||||||
|
|
||||||
@TimestepSamplingConfig.register_subclass("beta")
|
|
||||||
@dataclass
|
|
||||||
class BetaTimestepSamplingConfig(TimestepSamplingConfig):
|
|
||||||
"""Beta distribution timestep sampling.
|
|
||||||
|
|
||||||
Samples from Beta distribution emphasizing low timesteps (high noise).
|
|
||||||
|
|
||||||
This was inspired on the work from Physical Intelligence PI-0 model,
|
|
||||||
where they suggested the beta distribution for sampling timesteps
|
|
||||||
during training improved sample quality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
strategy_name: str = field(default="beta", init=False)
|
|
||||||
|
|
||||||
s: float = 0.999 # Max timestep threshold for beta sampling
|
|
||||||
alpha: float = 1.5 # Beta distribution alpha parameter
|
|
||||||
beta: float = 1.0 # Beta distribution beta parameter
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if not (0.0 < self.s <= 1.0):
|
|
||||||
raise ValueError(f"s must be in (0, 1], got {self.s}")
|
|
||||||
|
|
||||||
if self.alpha <= 0:
|
|
||||||
raise ValueError(f"alpha must be positive, got {self.alpha}")
|
|
||||||
|
|
||||||
if self.beta <= 0:
|
|
||||||
raise ValueError(f"beta must be positive, got {self.beta}")
|
|
||||||
|
|
||||||
|
|
||||||
@ObjectiveConfig.register_subclass("flow_matching")
|
|
||||||
@dataclass
|
|
||||||
class FlowMatchingConfig(ObjectiveConfig):
|
|
||||||
"""Configuration for flow matching training and inference.
|
|
||||||
|
|
||||||
These parameters control the velocity field learning and ODE integration
|
|
||||||
process for flow matching models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
objective_name: str = field(default="flow_matching", init=False)
|
|
||||||
|
|
||||||
# Flow path construction
|
|
||||||
sigma_min: float = 0.0 # Minimum noise level in flow interpolation path
|
|
||||||
|
|
||||||
# ODE integration for inference
|
|
||||||
num_integration_steps: int = (
|
|
||||||
100 # Number of ODE integration steps (increased from 50 for smoother trajectories)
|
|
||||||
)
|
|
||||||
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
|
|
||||||
|
|
||||||
# Timestep sampling strategy for training
|
|
||||||
# Beta distribution found to be the most effective in practice, so it is the default
|
|
||||||
timestep_sampling: TimestepSamplingConfig = field(default_factory=BetaTimestepSamplingConfig)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if not (0.0 <= self.sigma_min <= 1.0):
|
|
||||||
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
|
|
||||||
|
|
||||||
if self.num_integration_steps <= 0:
|
|
||||||
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
|
|
||||||
|
|
||||||
if self.integration_method not in ["euler", "rk4"]:
|
|
||||||
raise ValueError(f"integration_method must be 'euler' or 'rk4', got {self.integration_method}")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TransformerConfig:
|
|
||||||
"""Configuration for Transformer-based prediction model.
|
|
||||||
|
|
||||||
These parameters control the transformer architecture used for noise/velocity
|
|
||||||
prediction in diffusion and flow matching models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Transformer architecture parameters
|
|
||||||
hidden_dim: int = 512 # Hidden dimension of transformer
|
|
||||||
num_layers: int = 6 # Number of transformer layers
|
|
||||||
num_heads: int = 8 # Number of attention heads
|
|
||||||
dropout: float = 0.1 # Dropout rate
|
|
||||||
use_positional_encoding: bool = False # Whether to use absolute positional encoding
|
|
||||||
diffusion_step_embed_dim: int = 256 # Timestep embedding size
|
|
||||||
|
|
||||||
# RoPE (Rotary Position Embedding) configuration
|
|
||||||
use_rope: bool = True # Whether to use Rotary Position Embedding in attention (baseline is True)
|
|
||||||
rope_base: float = 10000.0 # Base frequency for RoPE computation
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""Validate Transformer-specific parameters."""
|
|
||||||
if self.hidden_dim <= 0:
|
|
||||||
raise ValueError("hidden_dim must be positive")
|
|
||||||
|
|
||||||
if self.num_layers <= 0:
|
|
||||||
raise ValueError("num_layers must be positive")
|
|
||||||
|
|
||||||
if self.num_heads <= 0:
|
|
||||||
raise ValueError("num_heads must be positive")
|
|
||||||
|
|
||||||
if self.hidden_dim % self.num_heads != 0:
|
|
||||||
raise ValueError("hidden_dim must be divisible by num_heads")
|
|
||||||
|
|
||||||
if not (0.0 <= self.dropout <= 1.0):
|
|
||||||
raise ValueError("dropout must be between 0.0 and 1.0")
|
|
||||||
|
|
||||||
if self.diffusion_step_embed_dim <= 0:
|
|
||||||
raise ValueError("diffusion_step_embed_dim must be positive")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class VisionEncoderConfig:
|
|
||||||
"""Configuration for CLIP vision encoder.
|
|
||||||
|
|
||||||
Uses CLIPVisionModel from transformers library.
|
|
||||||
CLS token usage is handled automatically.
|
|
||||||
CLIP's internal preprocessing (resize to 224x224) can be overridden
|
|
||||||
by setting resize_shape and crop_shape.
|
|
||||||
|
|
||||||
All image preprocessing is centralized here:
|
|
||||||
1. Resize (optional) - resize images to target resolution
|
|
||||||
2. Crop (optional) - crop after resize, must be smaller than resize_shape
|
|
||||||
3. Random crop - whether to use random cropping during training
|
|
||||||
|
|
||||||
Any CLIP model from transformers can be used. Examples:
|
|
||||||
- openai/clip-vit-base-patch16 (default, 768 dims)
|
|
||||||
- openai/clip-vit-large-patch14 (1024 dims)
|
|
||||||
- laion/CLIP-ViT-B-32-xlaai256 (alternative CLIP model)
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_name: str = "openai/clip-vit-base-patch16"
|
|
||||||
use_separate_encoder_per_camera: bool = False
|
|
||||||
|
|
||||||
# Learning rate multiplier for vision encoder parameters
|
|
||||||
# Vision encoder learning rate = optimizer_lr * lr_multiplier
|
|
||||||
lr_multiplier: float = 0.1
|
|
||||||
|
|
||||||
# Image preprocessing (centralized)
|
|
||||||
resize_shape: tuple[int, int] | None = None
|
|
||||||
crop_shape: tuple[int, int] | None = (224, 224) # default input size for CLIP
|
|
||||||
crop_is_random: bool = True
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# Validate that model name contains "clip" to ensure correct encoder type
|
|
||||||
if "clip" not in self.model_name.lower():
|
|
||||||
raise ValueError(
|
|
||||||
f"model_name must be a CLIP model from transformers (contain 'clip'), got '{self.model_name}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.resize_shape
|
|
||||||
and self.crop_shape
|
|
||||||
and (self.crop_shape[0] > self.resize_shape[0] or self.crop_shape[1] > self.resize_shape[1])
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"crop_shape {self.crop_shape} must be smaller than or equal to "
|
|
||||||
f"resize_shape {self.resize_shape}. Got crop={self.crop_shape}, resize={self.resize_shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TextEncoderConfig:
|
|
||||||
"""Configuration for CLIP text encoder.
|
|
||||||
|
|
||||||
Uses CLIP's text encoder to embed task descriptions, which are then
|
|
||||||
used to condition the policy. The text embeddings are processed by
|
|
||||||
a learnable projection layer before being concatenated into the
|
|
||||||
conditioning vector.
|
|
||||||
|
|
||||||
Any HuggingFace CLIP model can be used. Examples:
|
|
||||||
- openai/clip-vit-base-patch16 (default)
|
|
||||||
- openai/clip-vit-large-patch14
|
|
||||||
"""
|
|
||||||
|
|
||||||
model: str = "openai/clip-vit-base-patch16"
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# Validate that model name contains "clip" to ensure correct encoder type
|
|
||||||
if "clip" not in self.model.lower():
|
|
||||||
raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ObservationEncoderConfig:
|
|
||||||
"""Top-level configuration for observation encoding.
|
|
||||||
|
|
||||||
This config combines:
|
|
||||||
- Vision encoding (required): CLIP vision encoder from transformers
|
|
||||||
"""
|
|
||||||
|
|
||||||
vision: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
|
|
||||||
text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
|
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("multi_task_dit")
|
@PreTrainedConfig.register_subclass("multi_task_dit")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultiTaskDiTConfig(PreTrainedConfig):
|
class MultiTaskDiTConfig(PreTrainedConfig):
|
||||||
"""
|
"""Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
|
||||||
Configuration class for the Multi-Task Diffusion Transformer (DiT) policy.
|
|
||||||
|
A transformer-based policy that supports both diffusion and flow matching objectives
|
||||||
|
for multi-task robot learning with text and vision conditioning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Temporal structure - controls how the policy processes time and predicts actions
|
n_obs_steps: int = 2 # Number of observation steps for temporal context
|
||||||
n_obs_steps: int = 2 # num observations for temporal context (..., t-1, t)
|
horizon: int = 32 # Number of action steps to predict
|
||||||
horizon: int = 100 # predicted action steps into the future
|
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
|
||||||
n_action_steps: int = 24 # actions per policy call (receding horizon) -- ~0.8s is a good place to start
|
|
||||||
|
|
||||||
# Normalization strategy - critical for diffusion model performance
|
# Objective Selection
|
||||||
|
objective: Literal["diffusion", "flow_matching"] = "diffusion"
|
||||||
|
|
||||||
|
# --- Diffusion-specific (used when objective="diffusion") ---
|
||||||
|
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
|
||||||
|
num_train_timesteps: int = 100 # Number of diffusion timesteps
|
||||||
|
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
|
||||||
|
beta_start: float = 0.0001 # Starting noise level
|
||||||
|
beta_end: float = 0.02 # Ending noise level
|
||||||
|
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
|
||||||
|
clip_sample: bool = True # Clip samples during denoising
|
||||||
|
clip_sample_range: float = 1.0 # Clipping range [-x, x]
|
||||||
|
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
|
||||||
|
|
||||||
|
# --- Flow Matching-specific (used when objective="flow_matching") ---
|
||||||
|
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
|
||||||
|
num_integration_steps: int = 100 # ODE integration steps at inference
|
||||||
|
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
|
||||||
|
timestep_sampling_strategy: Literal["uniform", "beta"] = "beta"
|
||||||
|
|
||||||
|
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
|
||||||
|
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
|
||||||
|
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
|
||||||
|
|
||||||
|
# Transformer Architecture
|
||||||
|
hidden_dim: int = 512 # Transformer hidden dimension
|
||||||
|
num_layers: int = 6 # Number of transformer layers
|
||||||
|
num_heads: int = 8 # Number of attention heads
|
||||||
|
dropout: float = 0.1 # Dropout rate
|
||||||
|
use_positional_encoding: bool = False # Use absolute positional encoding
|
||||||
|
timestep_embed_dim: int = 256 # Timestep embedding dimension
|
||||||
|
use_rope: bool = True # Use Rotary Position Embedding
|
||||||
|
rope_base: float = 10000.0 # RoPE base frequency
|
||||||
|
|
||||||
|
# Vision Encoder (CLIP)
|
||||||
|
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
||||||
|
use_separate_encoder_per_camera: bool = False # Separate encoder per camera view
|
||||||
|
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
|
||||||
|
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
|
||||||
|
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
|
||||||
|
image_crop_is_random: bool = True # Random crop during training, center at inference
|
||||||
|
|
||||||
|
# Text Encoder (CLIP)
|
||||||
|
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
||||||
|
|
||||||
|
# Normalization
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"VISUAL": NormalizationMode.MEAN_STD, # Standard ImageNet normalization for vision
|
"VISUAL": NormalizationMode.MEAN_STD,
|
||||||
"STATE": NormalizationMode.MIN_MAX, # [-1,1] range for proper diffusion clipping
|
"STATE": NormalizationMode.MIN_MAX,
|
||||||
"ACTION": NormalizationMode.MIN_MAX, # [-1,1] range required for diffusion process
|
"ACTION": NormalizationMode.MIN_MAX,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
drop_n_last_frames: int | None = None # Auto-calculated: horizon - n_action_steps - n_obs_steps + 1
|
# Training/Optimizer
|
||||||
observation_encoder: ObservationEncoderConfig = field(default_factory=ObservationEncoderConfig)
|
|
||||||
transformer: TransformerConfig = field(default_factory=TransformerConfig)
|
|
||||||
objective: ObjectiveConfig = field(default_factory=DiffusionConfig)
|
|
||||||
do_mask_loss_for_padding: bool = False # same logic as is implemented in LeRobot DP implementation
|
|
||||||
|
|
||||||
# training optimizer and scheduler hyperparameters
|
|
||||||
optimizer_lr: float = 2e-5
|
optimizer_lr: float = 2e-5
|
||||||
optimizer_betas: tuple = (0.95, 0.999)
|
optimizer_betas: tuple = (0.95, 0.999)
|
||||||
optimizer_eps: float = 1e-8
|
optimizer_eps: float = 1e-8
|
||||||
optimizer_weight_decay: float = 0.0 # No weight decay is suggested to be optimal
|
optimizer_weight_decay: float = 0.0
|
||||||
scheduler_name: str = "cosine"
|
scheduler_name: str = "cosine"
|
||||||
scheduler_warmup_steps: int = 0 # No warmup found to be optimal
|
scheduler_warmup_steps: int = 0
|
||||||
|
do_mask_loss_for_padding: bool = False
|
||||||
|
|
||||||
|
# Auto-calculated
|
||||||
|
drop_n_last_frames: int | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
@@ -320,11 +108,78 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
|||||||
if self.drop_n_last_frames is None:
|
if self.drop_n_last_frames is None:
|
||||||
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
|
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> AdamConfig:
|
self._validate()
|
||||||
"""Return Adam optimizer configuration optimized for diffusion training.
|
|
||||||
|
|
||||||
Note: Vision encoder learning rate is set separately via get_optim_params.
|
def _validate(self):
|
||||||
"""
|
"""Validate configuration parameters."""
|
||||||
|
# Transformer validation
|
||||||
|
if self.hidden_dim <= 0:
|
||||||
|
raise ValueError("hidden_dim must be positive")
|
||||||
|
if self.num_layers <= 0:
|
||||||
|
raise ValueError("num_layers must be positive")
|
||||||
|
if self.num_heads <= 0:
|
||||||
|
raise ValueError("num_heads must be positive")
|
||||||
|
if self.hidden_dim % self.num_heads != 0:
|
||||||
|
raise ValueError("hidden_dim must be divisible by num_heads")
|
||||||
|
if not (0.0 <= self.dropout <= 1.0):
|
||||||
|
raise ValueError("dropout must be between 0.0 and 1.0")
|
||||||
|
|
||||||
|
# Vision encoder validation
|
||||||
|
if "clip" not in self.vision_encoder_name.lower():
|
||||||
|
raise ValueError(
|
||||||
|
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
self.image_resize_shape
|
||||||
|
and self.image_crop_shape
|
||||||
|
and (
|
||||||
|
self.image_crop_shape[0] > self.image_resize_shape[0]
|
||||||
|
or self.image_crop_shape[1] > self.image_resize_shape[1]
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"image_crop_shape {self.image_crop_shape} must be <= image_resize_shape {self.image_resize_shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Text encoder validation
|
||||||
|
if "clip" not in self.text_encoder_name.lower():
|
||||||
|
raise ValueError(
|
||||||
|
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Objective-specific validation
|
||||||
|
if self.objective == "diffusion":
|
||||||
|
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
|
||||||
|
)
|
||||||
|
if self.prediction_type not in ["epsilon", "sample"]:
|
||||||
|
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
|
||||||
|
if self.num_train_timesteps <= 0:
|
||||||
|
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
|
||||||
|
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
|
||||||
|
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
|
||||||
|
|
||||||
|
elif self.objective == "flow_matching":
|
||||||
|
if not (0.0 <= self.sigma_min <= 1.0):
|
||||||
|
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
|
||||||
|
if self.num_integration_steps <= 0:
|
||||||
|
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
|
||||||
|
if self.integration_method not in ["euler", "rk4"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
|
||||||
|
)
|
||||||
|
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
|
||||||
|
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
|
||||||
|
if self.timestep_sampling_strategy == "beta":
|
||||||
|
if not (0.0 < self.timestep_sampling_s <= 1.0):
|
||||||
|
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
|
||||||
|
if self.timestep_sampling_alpha <= 0:
|
||||||
|
raise ValueError("timestep_sampling_alpha must be positive")
|
||||||
|
if self.timestep_sampling_beta <= 0:
|
||||||
|
raise ValueError("timestep_sampling_beta must be positive")
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> AdamConfig:
|
||||||
return AdamConfig(
|
return AdamConfig(
|
||||||
lr=self.optimizer_lr,
|
lr=self.optimizer_lr,
|
||||||
betas=self.optimizer_betas,
|
betas=self.optimizer_betas,
|
||||||
@@ -333,7 +188,6 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||||
"""Return learning rate scheduler configuration."""
|
|
||||||
return DiffuserSchedulerConfig(
|
return DiffuserSchedulerConfig(
|
||||||
name=self.scheduler_name,
|
name=self.scheduler_name,
|
||||||
num_warmup_steps=self.scheduler_warmup_steps,
|
num_warmup_steps=self.scheduler_warmup_steps,
|
||||||
@@ -341,56 +195,41 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate that required input features are present and properly configured."""
|
"""Validate that required input features are present and properly configured."""
|
||||||
# Robot state is always present via self.robot_state_feature, so we don't need to enforce images/env_state
|
if self.image_crop_shape is not None:
|
||||||
# This allows for testing and simple state-only policies
|
|
||||||
|
|
||||||
# Validate crop shape fits within image dimensions
|
|
||||||
crop_shape = self.observation_encoder.vision.crop_shape
|
|
||||||
if crop_shape is not None:
|
|
||||||
for key, image_ft in self.image_features.items():
|
for key, image_ft in self.image_features.items():
|
||||||
if crop_shape[0] > image_ft.shape[1] or crop_shape[1] > image_ft.shape[2]:
|
if (
|
||||||
|
self.image_crop_shape[0] > image_ft.shape[1]
|
||||||
|
or self.image_crop_shape[1] > image_ft.shape[2]
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`crop_shape` should fit within the images shapes. Got {crop_shape} "
|
f"image_crop_shape {self.image_crop_shape} doesn't fit within image shape {image_ft.shape} "
|
||||||
f"for `crop_shape` and {image_ft.shape} for "
|
f"for '{key}'"
|
||||||
f"`{key}`."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure all images have same shape (current limitation)
|
|
||||||
if len(self.image_features) > 0:
|
if len(self.image_features) > 0:
|
||||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
first_key, first_ft = next(iter(self.image_features.items()))
|
||||||
for key, image_ft in self.image_features.items():
|
for key, image_ft in self.image_features.items():
|
||||||
if image_ft.shape != first_image_ft.shape:
|
if image_ft.shape != first_ft.shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def model_objective(self) -> str:
|
|
||||||
return self.objective.objective_name
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_diffusion(self) -> bool:
|
def is_diffusion(self) -> bool:
|
||||||
return isinstance(self.objective, DiffusionConfig)
|
return self.objective == "diffusion"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_flow_matching(self) -> bool:
|
def is_flow_matching(self) -> bool:
|
||||||
return isinstance(self.objective, FlowMatchingConfig)
|
return self.objective == "flow_matching"
|
||||||
|
|
||||||
def get_objective_config(self) -> DiffusionConfig | FlowMatchingConfig:
|
|
||||||
"""Get the objective-specific configuration with proper typing."""
|
|
||||||
return self.objective
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_delta_indices(self) -> list:
|
def observation_delta_indices(self) -> list:
|
||||||
"""Delta indices for stacking observations. Provides temporal context."""
|
|
||||||
return list(range(1 - self.n_obs_steps, 1))
|
return list(range(1 - self.n_obs_steps, 1))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list:
|
def action_delta_indices(self) -> list:
|
||||||
"""Delta indices for action horizon prediction."""
|
|
||||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reward_delta_indices(self) -> None:
|
def reward_delta_indices(self) -> None:
|
||||||
"""Indices for reward deltas (not used in diffusion policy)."""
|
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -52,23 +52,22 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
|||||||
action_dim = config.action_feature.shape[0]
|
action_dim = config.action_feature.shape[0]
|
||||||
horizon = config.horizon
|
horizon = config.horizon
|
||||||
|
|
||||||
self.model_objective = config.model_objective
|
|
||||||
if config.is_diffusion:
|
if config.is_diffusion:
|
||||||
self.objective = DiffusionObjective(
|
self.objective = DiffusionObjective(
|
||||||
config.get_objective_config(),
|
config,
|
||||||
action_dim=action_dim,
|
action_dim=action_dim,
|
||||||
horizon=horizon,
|
horizon=horizon,
|
||||||
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
||||||
)
|
)
|
||||||
elif config.is_flow_matching:
|
elif config.is_flow_matching:
|
||||||
self.objective = FlowMatchingObjective(
|
self.objective = FlowMatchingObjective(
|
||||||
config.get_objective_config(),
|
config,
|
||||||
action_dim=action_dim,
|
action_dim=action_dim,
|
||||||
horizon=horizon,
|
horizon=horizon,
|
||||||
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported model_objective: {self.model_objective}")
|
raise ValueError(f"Unsupported objective: {config.objective}")
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@@ -90,7 +89,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
|||||||
{"params": non_vision_params},
|
{"params": non_vision_params},
|
||||||
{
|
{
|
||||||
"params": vision_encoder_params,
|
"params": vision_encoder_params,
|
||||||
"lr": self.config.optimizer_lr * self.config.observation_encoder.vision.lr_multiplier,
|
"lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -118,8 +117,8 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
|||||||
if self.config.env_state_feature:
|
if self.config.env_state_feature:
|
||||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
if self.config.observation_encoder.text:
|
# Always include task queue for text conditioning
|
||||||
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
|
|||||||
@@ -14,14 +14,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""Objective implementations for Multi-Task DiT policy.
|
||||||
This module contains a base objective class, and implementation of the objective
|
|
||||||
classes for use in the Multi-Task Diffusion Transformer Policy.
|
|
||||||
|
|
||||||
Architecture:
|
- DiffusionObjective: Standard DDPM/DDIM diffusion
|
||||||
- BaseObjective: Abstract interface definition
|
- FlowMatchingObjective: Flow matching with ODE integration
|
||||||
- DiffusionObjective: Implements standard DDPM/DDIM diffusion objective
|
|
||||||
- FlowMatchingObjective: Implements flow matching objective
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -35,10 +31,7 @@ from torch import Tensor
|
|||||||
|
|
||||||
|
|
||||||
class BaseObjective(ABC):
|
class BaseObjective(ABC):
|
||||||
"""
|
"""Base class for objectives used in Multi-Task DiT policy."""
|
||||||
Base class for objectives used in Multi-Task DiT policy.
|
|
||||||
Defines the interface for training loss computation and conditional sampling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, action_dim: int, horizon: int):
|
def __init__(self, config, action_dim: int, horizon: int):
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -47,38 +40,17 @@ class BaseObjective(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||||
"""Compute training loss for the objective.
|
"""Compute training loss."""
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The prediction network
|
|
||||||
batch: Training batch with observations and actions
|
|
||||||
conditioning_vec: Encoded observation features for conditioning
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Scalar loss tensor
|
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||||
"""Generate action samples conditioned on embedded observation features.
|
"""Generate action samples conditioned on observations."""
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The prediction network
|
|
||||||
batch_size: The number of samples to generate
|
|
||||||
conditioning_vec: Encoded observation features for conditioning
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Generated action sequences (batch_size, horizon, action_dim)
|
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DiffusionObjective(BaseObjective):
|
class DiffusionObjective(BaseObjective):
|
||||||
"""Standard diffusion (DDPM/DDIM) objective implementation.
|
"""Standard diffusion (DDPM/DDIM) objective implementation."""
|
||||||
|
|
||||||
Contains the noise scheduler, training loss, and conditional sampling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||||
super().__init__(config, action_dim, horizon)
|
super().__init__(config, action_dim, horizon)
|
||||||
@@ -90,8 +62,8 @@ class DiffusionObjective(BaseObjective):
|
|||||||
"beta_start": config.beta_start,
|
"beta_start": config.beta_start,
|
||||||
"beta_end": config.beta_end,
|
"beta_end": config.beta_end,
|
||||||
"beta_schedule": config.beta_schedule,
|
"beta_schedule": config.beta_schedule,
|
||||||
"clip_sample": getattr(config, "clip_sample", True),
|
"clip_sample": config.clip_sample,
|
||||||
"clip_sample_range": getattr(config, "clip_sample_range", 1.0),
|
"clip_sample_range": config.clip_sample_range,
|
||||||
"prediction_type": config.prediction_type,
|
"prediction_type": config.prediction_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,7 +77,7 @@ class DiffusionObjective(BaseObjective):
|
|||||||
# Inference steps default to training steps if not provided
|
# Inference steps default to training steps if not provided
|
||||||
self.num_inference_steps = (
|
self.num_inference_steps = (
|
||||||
config.num_inference_steps
|
config.num_inference_steps
|
||||||
if getattr(config, "num_inference_steps", None) is not None
|
if config.num_inference_steps is not None
|
||||||
else self.noise_scheduler.config.num_train_timesteps
|
else self.noise_scheduler.config.num_train_timesteps
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -161,40 +133,29 @@ class DiffusionObjective(BaseObjective):
|
|||||||
|
|
||||||
|
|
||||||
class FlowMatchingObjective(BaseObjective):
|
class FlowMatchingObjective(BaseObjective):
|
||||||
"""
|
"""Flow matching objective: trains a model to predict velocity fields."""
|
||||||
Flow matching objective: trains a model to predict velocity fields v_θ(x, t) that transports
|
|
||||||
noise to data. This basically interpolates as path between noise and a trained distribution
|
|
||||||
with a set target velocity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||||
super().__init__(config, action_dim, horizon)
|
super().__init__(config, action_dim, horizon)
|
||||||
self.do_mask_loss_for_padding = do_mask_loss_for_padding
|
self.do_mask_loss_for_padding = do_mask_loss_for_padding
|
||||||
|
|
||||||
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
|
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
|
||||||
"""Sample timesteps according to configured strategy.
|
"""Sample timesteps according to configured strategy."""
|
||||||
|
if self.config.timestep_sampling_strategy == "uniform":
|
||||||
Uniform: Sample t uniformly from [0,1]
|
|
||||||
Beta: Sample t from Beta(α,β) scaled to [0,s], emphasizing high noise (low t)
|
|
||||||
"""
|
|
||||||
if self.config.timestep_sampling.strategy_name == "uniform":
|
|
||||||
return torch.rand(batch_size, device=device)
|
return torch.rand(batch_size, device=device)
|
||||||
elif self.config.timestep_sampling.strategy_name == "beta":
|
elif self.config.timestep_sampling_strategy == "beta":
|
||||||
# Sample u ~ Beta(α, β) then transform: t = s(1-u)
|
# Sample u ~ Beta(α, β) then transform: t = s(1-u)
|
||||||
# This emphasizes t near 0 (high noise) when α > β
|
# This emphasizes t near 0 (high noise) when α > β
|
||||||
beta_dist = torch.distributions.Beta(
|
beta_dist = torch.distributions.Beta(
|
||||||
self.config.timestep_sampling.alpha, self.config.timestep_sampling.beta
|
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
|
||||||
)
|
)
|
||||||
u = beta_dist.sample((batch_size,)).to(device)
|
u = beta_dist.sample((batch_size,)).to(device)
|
||||||
return self.config.timestep_sampling.s * (1.0 - u)
|
return self.config.timestep_sampling_s * (1.0 - u)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling.strategy_name}")
|
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
|
||||||
|
|
||||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||||
"""Compute flow matching training loss.
|
"""Compute flow matching training loss."""
|
||||||
|
|
||||||
Trains the model to predict the velocity field along linear interpolation paths.
|
|
||||||
"""
|
|
||||||
data = batch["action"] # Clean action sequences (B, T, D)
|
data = batch["action"] # Clean action sequences (B, T, D)
|
||||||
batch_size = data.shape[0]
|
batch_size = data.shape[0]
|
||||||
device = data.device
|
device = data.device
|
||||||
@@ -217,10 +178,7 @@ class FlowMatchingObjective(BaseObjective):
|
|||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||||
"""Generate actions by integrating the learned velocity field via ODE.
|
"""Generate actions by integrating the learned velocity field via ODE."""
|
||||||
|
|
||||||
Solves: dx/dt = v_θ(x,t) from t=0 (noise) to t=1 (data)
|
|
||||||
"""
|
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
dtype = next(model.parameters()).dtype
|
dtype = next(model.parameters()).dtype
|
||||||
|
|
||||||
@@ -244,9 +202,7 @@ class FlowMatchingObjective(BaseObjective):
|
|||||||
def _euler_integrate(
|
def _euler_integrate(
|
||||||
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)"""
|
||||||
Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)
|
|
||||||
"""
|
|
||||||
x = x_init
|
x = x_init
|
||||||
|
|
||||||
for i in range(len(time_grid) - 1):
|
for i in range(len(time_grid) - 1):
|
||||||
@@ -268,23 +224,10 @@ class FlowMatchingObjective(BaseObjective):
|
|||||||
def _rk4_integrate(
|
def _rk4_integrate(
|
||||||
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""4th-order Runge-Kutta integration.
|
"""4th-order Runge-Kutta integration."""
|
||||||
|
|
||||||
Uses 4 velocity evaluations per step:
|
|
||||||
k1 = v(x, t)
|
|
||||||
k2 = v(x + dt·k1/2, t + dt/2)
|
|
||||||
k3 = v(x + dt·k2/2, t + dt/2)
|
|
||||||
k4 = v(x + dt·k3, t + dt)
|
|
||||||
x_next = x + dt/6·(k1 + 2k2 + 2k3 + k4)
|
|
||||||
|
|
||||||
4x slower than Euler but more accurate.
|
|
||||||
|
|
||||||
Note: In practice, this seems to not matter much
|
|
||||||
"""
|
|
||||||
x = x_init
|
x = x_init
|
||||||
|
|
||||||
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
|
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
|
||||||
"""dynamics helper to get velocity at (x, t)"""
|
|
||||||
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
|
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
|
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
|
||||||
|
|||||||
@@ -68,9 +68,7 @@ class CLIPVisionEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CLIPTextEncoder(nn.Module):
|
class CLIPTextEncoder(nn.Module):
|
||||||
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
|
"""CLIP text encoder with frozen weights and a learnable projection layer."""
|
||||||
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -126,21 +124,20 @@ class ObservationEncoder(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
vision_config = config.observation_encoder.vision
|
|
||||||
|
|
||||||
self._setup_preprocessing(vision_config)
|
self._setup_preprocessing(config)
|
||||||
|
|
||||||
if config.image_features:
|
if config.image_features:
|
||||||
self.num_cameras = len(config.image_features)
|
self.num_cameras = len(config.image_features)
|
||||||
self.camera_names = list(config.image_features.keys()) # Preserve ordering
|
self.camera_names = list(config.image_features.keys())
|
||||||
|
|
||||||
if vision_config.use_separate_encoder_per_camera:
|
if config.use_separate_encoder_per_camera:
|
||||||
self.vision_encoders = nn.ModuleList(
|
self.vision_encoders = nn.ModuleList(
|
||||||
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
|
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
|
||||||
)
|
)
|
||||||
self.vision_encoder = None
|
self.vision_encoder = None
|
||||||
else:
|
else:
|
||||||
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
|
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
|
||||||
self.vision_encoders = None
|
self.vision_encoders = None
|
||||||
else:
|
else:
|
||||||
self.vision_encoder = None
|
self.vision_encoder = None
|
||||||
@@ -158,9 +155,8 @@ class ObservationEncoder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.env_state_dim = 0
|
self.env_state_dim = 0
|
||||||
|
|
||||||
text_config = config.observation_encoder.text
|
self.text_dim = config.hidden_dim
|
||||||
self.text_dim = config.transformer.hidden_dim
|
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
|
||||||
self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
|
|
||||||
|
|
||||||
self._setup_vector_output()
|
self._setup_vector_output()
|
||||||
|
|
||||||
@@ -173,22 +169,23 @@ class ObservationEncoder(nn.Module):
|
|||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def _setup_preprocessing(self, vision_config):
|
def _setup_preprocessing(self, config):
|
||||||
"""Setup image preprocessing transforms."""
|
"""Setup image preprocessing transforms."""
|
||||||
if vision_config.resize_shape is not None:
|
if config.image_resize_shape is not None:
|
||||||
self.do_resize = True
|
self.do_resize = True
|
||||||
self.resize = torchvision.transforms.Resize(
|
self.resize = torchvision.transforms.Resize(
|
||||||
size=vision_config.resize_shape,
|
size=config.image_resize_shape,
|
||||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
|
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
|
||||||
antialias=True,
|
antialias=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.do_resize = False
|
self.do_resize = False
|
||||||
if vision_config.crop_shape is not None:
|
|
||||||
|
if config.image_crop_shape is not None:
|
||||||
self.do_crop = True
|
self.do_crop = True
|
||||||
self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
|
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
|
||||||
if vision_config.crop_is_random:
|
if config.image_crop_is_random:
|
||||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
|
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
|
||||||
else:
|
else:
|
||||||
self.maybe_random_crop = self.center_crop
|
self.maybe_random_crop = self.center_crop
|
||||||
else:
|
else:
|
||||||
@@ -199,7 +196,7 @@ class ObservationEncoder(nn.Module):
|
|||||||
|
|
||||||
# Vision features - get CLS token feature dimension
|
# Vision features - get CLS token feature dimension
|
||||||
if self.vision_encoder is not None or self.vision_encoders is not None:
|
if self.vision_encoder is not None or self.vision_encoders is not None:
|
||||||
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders.values()))
|
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
|
||||||
|
|
||||||
# Get output shape from encoder (deterministic for CLS tokens)
|
# Get output shape from encoder (deterministic for CLS tokens)
|
||||||
feature_map_shape = encoder_to_check.get_output_shape()
|
feature_map_shape = encoder_to_check.get_output_shape()
|
||||||
@@ -233,8 +230,7 @@ class ObservationEncoder(nn.Module):
|
|||||||
# Shape is (B, N, C, H, W) - add time dimension
|
# Shape is (B, N, C, H, W) - add time dimension
|
||||||
images = images.unsqueeze(1) # (B, 1, N, C, H, W)
|
images = images.unsqueeze(1) # (B, 1, N, C, H, W)
|
||||||
|
|
||||||
vision_config = self.config.observation_encoder.vision
|
if self.config.use_separate_encoder_per_camera:
|
||||||
if vision_config.use_separate_encoder_per_camera:
|
|
||||||
# Process each camera with its own encoder
|
# Process each camera with its own encoder
|
||||||
camera_features = []
|
camera_features = []
|
||||||
|
|
||||||
|
|||||||
@@ -308,32 +308,29 @@ class TransformerBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DiffusionTransformer(nn.Module):
|
class DiffusionTransformer(nn.Module):
|
||||||
"""
|
"""Transformer-based diffusion noise prediction model."""
|
||||||
Transformer-based diffusion noise prediction model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, conditioning_dim: int):
|
def __init__(self, config, conditioning_dim: int):
|
||||||
"""Initialize transformer for noise prediction.
|
"""Initialize transformer for noise prediction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Multi-Task DiTConfig with transformer parameters
|
config: MultiTaskDiTConfig with transformer parameters
|
||||||
conditioning_dim: Dimension of concatenated observation features
|
conditioning_dim: Dimension of concatenated observation features
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.transformer_config = config.transformer
|
|
||||||
self.conditioning_dim = conditioning_dim
|
self.conditioning_dim = conditioning_dim
|
||||||
|
|
||||||
self.action_dim = config.action_feature.shape[0]
|
self.action_dim = config.action_feature.shape[0]
|
||||||
self.horizon = config.horizon
|
self.horizon = config.horizon
|
||||||
self.hidden_size = self.transformer_config.hidden_dim
|
self.hidden_size = config.hidden_dim
|
||||||
self.num_layers = self.transformer_config.num_layers
|
self.num_layers = config.num_layers
|
||||||
self.num_heads = self.transformer_config.num_heads
|
self.num_heads = config.num_heads
|
||||||
self.dropout = self.transformer_config.dropout
|
self.dropout = config.dropout
|
||||||
self.use_rope = self.transformer_config.use_rope
|
self.use_rope = config.use_rope
|
||||||
|
|
||||||
self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
|
self.timestep_embed_dim = config.timestep_embed_dim
|
||||||
self.time_mlp = nn.Sequential(
|
self.time_mlp = nn.Sequential(
|
||||||
SinusoidalPosEmb(self.timestep_embed_dim),
|
SinusoidalPosEmb(self.timestep_embed_dim),
|
||||||
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
|
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
|
||||||
@@ -347,7 +344,7 @@ class DiffusionTransformer(nn.Module):
|
|||||||
# Project action dimensions to hidden size
|
# Project action dimensions to hidden size
|
||||||
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
|
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
|
||||||
|
|
||||||
if self.transformer_config.use_positional_encoding:
|
if config.use_positional_encoding:
|
||||||
# Learnable positional embeddings for sequence positions (absolute encoding)
|
# Learnable positional embeddings for sequence positions (absolute encoding)
|
||||||
self.pos_embedding = nn.Parameter(
|
self.pos_embedding = nn.Parameter(
|
||||||
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
|
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
|
||||||
@@ -363,8 +360,8 @@ class DiffusionTransformer(nn.Module):
|
|||||||
num_features=self.cond_dim,
|
num_features=self.cond_dim,
|
||||||
dropout=self.dropout,
|
dropout=self.dropout,
|
||||||
use_rope=self.use_rope,
|
use_rope=self.use_rope,
|
||||||
max_seq_len=self.horizon, # This remains fixed because we aren't generating variable length sequences
|
max_seq_len=self.horizon,
|
||||||
rope_base=getattr(self.transformer_config, "rope_base", 10000.0),
|
rope_base=config.rope_base,
|
||||||
)
|
)
|
||||||
for _ in range(self.num_layers)
|
for _ in range(self.num_layers)
|
||||||
]
|
]
|
||||||
@@ -377,9 +374,7 @@ class DiffusionTransformer(nn.Module):
|
|||||||
self._initialize_weights()
|
self._initialize_weights()
|
||||||
|
|
||||||
def _initialize_weights(self):
|
def _initialize_weights(self):
|
||||||
"""
|
"""Zero-initialize final linear layer of adaLN_modulation for training stability."""
|
||||||
Zero-initializing the final linear layer of adaLN_modulation in each block improves training stability
|
|
||||||
"""
|
|
||||||
for block in self.transformer_blocks:
|
for block in self.transformer_blocks:
|
||||||
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||||
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||||
|
|||||||
@@ -28,11 +28,7 @@ import torch
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
|
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||||
DiffusionConfig,
|
|
||||||
FlowMatchingConfig,
|
|
||||||
MultiTaskDiTConfig,
|
|
||||||
)
|
|
||||||
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||||
@@ -108,13 +104,12 @@ def create_config(
|
|||||||
n_obs_steps=n_obs_steps,
|
n_obs_steps=n_obs_steps,
|
||||||
horizon=horizon,
|
horizon=horizon,
|
||||||
n_action_steps=n_action_steps,
|
n_action_steps=n_action_steps,
|
||||||
|
# Use smaller model for faster tests
|
||||||
|
hidden_dim=128,
|
||||||
|
num_layers=2,
|
||||||
|
num_heads=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use smaller model for faster tests
|
|
||||||
config.transformer.hidden_dim = 128
|
|
||||||
config.transformer.num_layers = 2
|
|
||||||
config.transformer.num_heads = 4
|
|
||||||
|
|
||||||
config.validate_features()
|
config.validate_features()
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@@ -189,18 +184,28 @@ def test_multi_task_dit_policy_diffusion_objective():
|
|||||||
horizon = 16
|
horizon = 16
|
||||||
n_action_steps = 8
|
n_action_steps = 8
|
||||||
|
|
||||||
config = create_config(
|
input_features = {
|
||||||
state_dim=state_dim,
|
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||||
action_dim=action_dim,
|
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||||
|
}
|
||||||
|
|
||||||
|
config = MultiTaskDiTConfig(
|
||||||
|
input_features=input_features,
|
||||||
|
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||||
n_obs_steps=n_obs_steps,
|
n_obs_steps=n_obs_steps,
|
||||||
horizon=horizon,
|
horizon=horizon,
|
||||||
n_action_steps=n_action_steps,
|
n_action_steps=n_action_steps,
|
||||||
)
|
# Use diffusion objective
|
||||||
config.objective = DiffusionConfig(
|
objective="diffusion",
|
||||||
noise_scheduler_type="DDPM",
|
noise_scheduler_type="DDPM",
|
||||||
num_train_timesteps=100,
|
num_train_timesteps=100,
|
||||||
num_inference_steps=10,
|
num_inference_steps=10,
|
||||||
|
# Smaller model for tests
|
||||||
|
hidden_dim=128,
|
||||||
|
num_layers=2,
|
||||||
|
num_heads=4,
|
||||||
)
|
)
|
||||||
|
config.validate_features()
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
policy.train()
|
policy.train()
|
||||||
@@ -235,18 +240,28 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
|||||||
horizon = 16
|
horizon = 16
|
||||||
n_action_steps = 8
|
n_action_steps = 8
|
||||||
|
|
||||||
config = create_config(
|
input_features = {
|
||||||
state_dim=state_dim,
|
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||||
action_dim=action_dim,
|
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||||
|
}
|
||||||
|
|
||||||
|
config = MultiTaskDiTConfig(
|
||||||
|
input_features=input_features,
|
||||||
|
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||||
n_obs_steps=n_obs_steps,
|
n_obs_steps=n_obs_steps,
|
||||||
horizon=horizon,
|
horizon=horizon,
|
||||||
n_action_steps=n_action_steps,
|
n_action_steps=n_action_steps,
|
||||||
)
|
# Use flow matching objective
|
||||||
config.objective = FlowMatchingConfig(
|
objective="flow_matching",
|
||||||
sigma_min=0.0,
|
sigma_min=0.0,
|
||||||
num_integration_steps=10, # Use fewer steps for faster tests
|
num_integration_steps=10, # Fewer steps for faster tests
|
||||||
integration_method="euler",
|
integration_method="euler",
|
||||||
|
# Smaller model for tests
|
||||||
|
hidden_dim=128,
|
||||||
|
num_layers=2,
|
||||||
|
num_heads=4,
|
||||||
)
|
)
|
||||||
|
config.validate_features()
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
policy.train()
|
policy.train()
|
||||||
@@ -373,5 +388,5 @@ def test_multi_task_dit_policy_get_optim_params():
|
|||||||
# Second group is vision encoder params with different lr
|
# Second group is vision encoder params with different lr
|
||||||
assert "params" in param_groups[1]
|
assert "params" in param_groups[1]
|
||||||
assert "lr" in param_groups[1]
|
assert "lr" in param_groups[1]
|
||||||
expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier
|
expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
|
||||||
assert param_groups[1]["lr"] == expected_lr
|
assert param_groups[1]["lr"] == expected_lr
|
||||||
|
|||||||
Reference in New Issue
Block a user