simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives

This commit is contained in:
Bryson Jones
2025-12-10 11:45:59 -08:00
parent cdacc090cd
commit 103230c64c
7 changed files with 242 additions and 454 deletions
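As an illustrative sketch (not part of the diff), the flattened config is now constructed in one
flat namespace, mirroring the updated tests at the bottom of this commit; feature shapes here are
placeholders:

    from lerobot.configs.types import FeatureType, PolicyFeature
    from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
    from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE

    config = MultiTaskDiTConfig(
        input_features={
            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
            f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
        },
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,))},
        objective="flow_matching",          # plain literal instead of a nested ObjectiveConfig
        num_integration_steps=10,           # flow-matching-only field, now at the top level
        timestep_sampling_strategy="beta",  # diffusion-only fields are simply ignored here
    )
    config.validate_features()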
@@ -0,0 +1 @@
#!/usr/bin/env python
@@ -15,8 +15,7 @@
# limitations under the License.
from dataclasses import dataclass, field
import draccus
from typing import Literal
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@@ -24,295 +23,84 @@ from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@dataclass
class ObjectiveConfig(draccus.ChoiceRegistry):
"""Base configuration for model objectives (diffusion, flow matching, etc.)."""
pass
@ObjectiveConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig(ObjectiveConfig):
"""Configuration for standard diffusion model training and inference.
These parameters control the noise scheduling and denoising process for
standard DDPM/DDIM diffusion models.
"""
objective_name: str = field(default="diffusion", init=False)
# Noise scheduler configuration - controls diffusion process
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # 100 noise levels for fine-grained control
beta_schedule: str = "squaredcos_cap_v2" # Cosine schedule prevents extreme noise
beta_start: float = 0.0001 # Small initial noise level
beta_end: float = 0.02 # Moderate final noise level
prediction_type: str = "epsilon" # Predict noise (works better than direct prediction)
clip_sample: bool = True # Prevent extreme action values
clip_sample_range: float = 1.0 # Clip to [-1, 1] range
# Inference configuration
num_inference_steps: int | None = None # Default to num_train_timesteps
def __post_init__(self):
"""Validate diffusion-specific parameters."""
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(
"beta values must satisfy 0 <= beta_start <= beta_end <= 1, "
f"got {self.beta_start}, {self.beta_end}"
)
@dataclass
class TimestepSamplingConfig(draccus.ChoiceRegistry):
"""Base configuration for timestep sampling strategies during training."""
pass
@TimestepSamplingConfig.register_subclass("uniform")
@dataclass
class UniformTimestepSamplingConfig(TimestepSamplingConfig):
"""Uniform timestep sampling from [0, 1]."""
strategy_name: str = field(default="uniform", init=False)
@TimestepSamplingConfig.register_subclass("beta")
@dataclass
class BetaTimestepSamplingConfig(TimestepSamplingConfig):
"""Beta distribution timestep sampling.
Samples from a Beta distribution emphasizing low timesteps (high noise).
This was inspired by the Physical Intelligence PI-0 model, which suggested
that sampling timesteps from a beta distribution during training improves
sample quality.
"""
strategy_name: str = field(default="beta", init=False)
s: float = 0.999 # Max timestep threshold for beta sampling
alpha: float = 1.5 # Beta distribution alpha parameter
beta: float = 1.0 # Beta distribution beta parameter
def __post_init__(self):
if not (0.0 < self.s <= 1.0):
raise ValueError(f"s must be in (0, 1], got {self.s}")
if self.alpha <= 0:
raise ValueError(f"alpha must be positive, got {self.alpha}")
if self.beta <= 0:
raise ValueError(f"beta must be positive, got {self.beta}")
@ObjectiveConfig.register_subclass("flow_matching")
@dataclass
class FlowMatchingConfig(ObjectiveConfig):
"""Configuration for flow matching training and inference.
These parameters control the velocity field learning and ODE integration
process for flow matching models.
"""
objective_name: str = field(default="flow_matching", init=False)
# Flow path construction
sigma_min: float = 0.0 # Minimum noise level in flow interpolation path
# ODE integration for inference
num_integration_steps: int = (
100 # Number of ODE integration steps (increased from 50 for smoother trajectories)
)
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
# Timestep sampling strategy for training
# The beta distribution was found to be the most effective in practice, so it is the default
timestep_sampling: TimestepSamplingConfig = field(default_factory=BetaTimestepSamplingConfig)
def __post_init__(self):
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(f"integration_method must be 'euler' or 'rk4', got {self.integration_method}")
@dataclass
class TransformerConfig:
"""Configuration for Transformer-based prediction model.
These parameters control the transformer architecture used for noise/velocity
prediction in diffusion and flow matching models.
"""
# Transformer architecture parameters
hidden_dim: int = 512 # Hidden dimension of transformer
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = False # Whether to use absolute positional encoding
diffusion_step_embed_dim: int = 256 # Timestep embedding size
# RoPE (Rotary Position Embedding) configuration
use_rope: bool = True # Whether to use Rotary Position Embedding in attention (baseline is True)
rope_base: float = 10000.0 # Base frequency for RoPE computation
def __post_init__(self):
"""Validate Transformer-specific parameters."""
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
if self.diffusion_step_embed_dim <= 0:
raise ValueError("diffusion_step_embed_dim must be positive")
@dataclass
class VisionEncoderConfig:
"""Configuration for CLIP vision encoder.
Uses CLIPVisionModel from transformers library.
CLS token usage is handled automatically.
CLIP's internal preprocessing (resize to 224x224) can be overridden
by setting resize_shape and crop_shape.
All image preprocessing is centralized here:
1. Resize (optional) - resize images to target resolution
2. Crop (optional) - crop after resize, must be smaller than resize_shape
3. Random crop - whether to use random cropping during training
Any CLIP model from transformers can be used. Examples:
- openai/clip-vit-base-patch16 (default, 768 dims)
- openai/clip-vit-large-patch14 (1024 dims)
- other CLIP variants on the Hugging Face Hub, e.g. from LAION (alternative CLIP models)
"""
model_name: str = "openai/clip-vit-base-patch16"
use_separate_encoder_per_camera: bool = False
# Learning rate multiplier for vision encoder parameters
# Vision encoder learning rate = optimizer_lr * lr_multiplier
lr_multiplier: float = 0.1
# Image preprocessing (centralized)
resize_shape: tuple[int, int] | None = None
crop_shape: tuple[int, int] | None = (224, 224) # default input size for CLIP
crop_is_random: bool = True
def __post_init__(self):
# Validate that model name contains "clip" to ensure correct encoder type
if "clip" not in self.model_name.lower():
raise ValueError(
f"model_name must be a CLIP model from transformers (contain 'clip'), got '{self.model_name}'"
)
if (
self.resize_shape
and self.crop_shape
and (self.crop_shape[0] > self.resize_shape[0] or self.crop_shape[1] > self.resize_shape[1])
):
raise ValueError(
f"crop_shape {self.crop_shape} must be smaller than or equal to "
f"resize_shape {self.resize_shape}. Got crop={self.crop_shape}, resize={self.resize_shape}"
)
@dataclass
class TextEncoderConfig:
"""Configuration for CLIP text encoder.
Uses CLIP's text encoder to embed task descriptions, which are then
used to condition the policy. The text embeddings are processed by
a learnable projection layer before being concatenated into the
conditioning vector.
Any HuggingFace CLIP model can be used. Examples:
- openai/clip-vit-base-patch16 (default)
- openai/clip-vit-large-patch14
"""
model: str = "openai/clip-vit-base-patch16"
def __post_init__(self):
# Validate that model name contains "clip" to ensure correct encoder type
if "clip" not in self.model.lower():
raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
@dataclass
class ObservationEncoderConfig:
"""Top-level configuration for observation encoding.
This config combines:
- Vision encoding (required): CLIP vision encoder from transformers
"""
vision: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
@PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass
class MultiTaskDiTConfig(PreTrainedConfig):
"""
Configuration class for the Multi-Task Diffusion Transformer (DiT) policy.
"""Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
A transformer-based policy that supports both diffusion and flow matching objectives
for multi-task robot learning with text and vision conditioning.
"""
# Temporal structure - controls how the policy processes time and predicts actions
n_obs_steps: int = 2 # num observations for temporal context (..., t-1, t)
horizon: int = 100 # predicted action steps into the future
n_action_steps: int = 24 # actions per policy call (receding horizon) -- ~0.8s is a good place to start
n_obs_steps: int = 2 # Number of observation steps for temporal context
horizon: int = 32 # Number of action steps to predict
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
# Normalization strategy - critical for diffusion model performance
# Objective Selection
objective: Literal["diffusion", "flow_matching"] = "diffusion"
# --- Diffusion-specific (used when objective="diffusion") ---
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # Number of diffusion timesteps
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
beta_start: float = 0.0001 # Starting noise level
beta_end: float = 0.02 # Ending noise level
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
clip_sample: bool = True # Clip samples during denoising
clip_sample_range: float = 1.0 # Clipping range [-x, x]
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
# --- Flow Matching-specific (used when objective="flow_matching") ---
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
num_integration_steps: int = 100 # ODE integration steps at inference
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
timestep_sampling_strategy: Literal["uniform", "beta"] = "beta"
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
# Transformer Architecture
hidden_dim: int = 512 # Transformer hidden dimension
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = False # Use absolute positional encoding
timestep_embed_dim: int = 256 # Timestep embedding dimension
use_rope: bool = True # Use Rotary Position Embedding
rope_base: float = 10000.0 # RoPE base frequency
# Vision Encoder (CLIP)
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
use_separate_encoder_per_camera: bool = False # Separate encoder per camera view
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
image_crop_is_random: bool = True # Random crop during training, center at inference
# Text Encoder (CLIP)
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD, # Standard ImageNet normalization for vision
"STATE": NormalizationMode.MIN_MAX, # [-1,1] range for proper diffusion clipping
"ACTION": NormalizationMode.MIN_MAX, # [-1,1] range required for diffusion process
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
drop_n_last_frames: int | None = None # Auto-calculated: horizon - n_action_steps - n_obs_steps + 1
observation_encoder: ObservationEncoderConfig = field(default_factory=ObservationEncoderConfig)
transformer: TransformerConfig = field(default_factory=TransformerConfig)
objective: ObjectiveConfig = field(default_factory=DiffusionConfig)
do_mask_loss_for_padding: bool = False # same logic as the LeRobot Diffusion Policy implementation
# training optimizer and scheduler hyperparameters
# Training/Optimizer
optimizer_lr: float = 2e-5
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0 # no weight decay was found to work best
optimizer_weight_decay: float = 0.0
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 0 # No warmup found to be optimal
scheduler_warmup_steps: int = 0
do_mask_loss_for_padding: bool = False
# Auto-calculated
drop_n_last_frames: int | None = None
def __post_init__(self):
super().__post_init__()
@@ -320,11 +108,78 @@ class MultiTaskDiTConfig(PreTrainedConfig):
if self.drop_n_last_frames is None:
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
def get_optimizer_preset(self) -> AdamConfig:
"""Return Adam optimizer configuration optimized for diffusion training.
self._validate()
Note: Vision encoder learning rate is set separately via get_optim_params.
"""
def _validate(self):
"""Validate configuration parameters."""
# Transformer validation
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
# Vision encoder validation
if "clip" not in self.vision_encoder_name.lower():
raise ValueError(
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
)
if (
self.image_resize_shape
and self.image_crop_shape
and (
self.image_crop_shape[0] > self.image_resize_shape[0]
or self.image_crop_shape[1] > self.image_resize_shape[1]
)
):
raise ValueError(
f"image_crop_shape {self.image_crop_shape} must be <= image_resize_shape {self.image_resize_shape}"
)
# Text encoder validation
if "clip" not in self.text_encoder_name.lower():
raise ValueError(
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
)
# Objective-specific validation
if self.objective == "diffusion":
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
elif self.objective == "flow_matching":
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
)
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
if self.timestep_sampling_strategy == "beta":
if not (0.0 < self.timestep_sampling_s <= 1.0):
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
if self.timestep_sampling_alpha <= 0:
raise ValueError("timestep_sampling_alpha must be positive")
if self.timestep_sampling_beta <= 0:
raise ValueError("timestep_sampling_beta must be positive")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
@@ -333,7 +188,6 @@ class MultiTaskDiTConfig(PreTrainedConfig):
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
"""Return learning rate scheduler configuration."""
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
@@ -341,56 +195,41 @@ class MultiTaskDiTConfig(PreTrainedConfig):
def validate_features(self) -> None:
"""Validate that required input features are present and properly configured."""
# Robot state is always present via self.robot_state_feature, so we don't need to enforce images/env_state
# This allows for testing and simple state-only policies
# Validate crop shape fits within image dimensions
crop_shape = self.observation_encoder.vision.crop_shape
if crop_shape is not None:
if self.image_crop_shape is not None:
for key, image_ft in self.image_features.items():
if crop_shape[0] > image_ft.shape[1] or crop_shape[1] > image_ft.shape[2]:
if (
self.image_crop_shape[0] > image_ft.shape[1]
or self.image_crop_shape[1] > image_ft.shape[2]
):
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
f"image_crop_shape {self.image_crop_shape} doesn't fit within image shape {image_ft.shape} "
f"for '{key}'"
)
# Ensure all images have same shape (current limitation)
if len(self.image_features) > 0:
first_image_key, first_image_ft = next(iter(self.image_features.items()))
first_key, first_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
if image_ft.shape != first_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
)
@property
def model_objective(self) -> str:
return self.objective.objective_name
@property
def is_diffusion(self) -> bool:
return isinstance(self.objective, DiffusionConfig)
return self.objective == "diffusion"
@property
def is_flow_matching(self) -> bool:
return isinstance(self.objective, FlowMatchingConfig)
def get_objective_config(self) -> DiffusionConfig | FlowMatchingConfig:
"""Get the objective-specific configuration with proper typing."""
return self.objective
return self.objective == "flow_matching"
@property
def observation_delta_indices(self) -> list:
"""Delta indices for stacking observations. Provides temporal context."""
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
"""Delta indices for action horizon prediction."""
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
"""Indices for reward deltas (not used in diffusion policy)."""
return None
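A quick worked check of the temporal indexing with the new defaults (n_obs_steps=2, horizon=32,
n_action_steps=24), shown here only for illustration:

    observation_delta_indices = list(range(1 - 2, 1))      # [-1, 0]
    action_delta_indices = list(range(1 - 2, 1 - 2 + 32))   # [-1, 0, ..., 30] (32 steps)
    drop_n_last_frames = 32 - 24 - 2 + 1                    # 7 (auto-calculated in __post_init__)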
@@ -52,23 +52,22 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
action_dim = config.action_feature.shape[0]
horizon = config.horizon
self.model_objective = config.model_objective
if config.is_diffusion:
self.objective = DiffusionObjective(
config.get_objective_config(),
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
elif config.is_flow_matching:
self.objective = FlowMatchingObjective(
config.get_objective_config(),
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
else:
raise ValueError(f"Unsupported model_objective: {self.model_objective}")
raise ValueError(f"Unsupported objective: {config.objective}")
self.reset()
@@ -90,7 +89,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
{"params": non_vision_params},
{
"params": vision_encoder_params,
"lr": self.config.optimizer_lr * self.config.observation_encoder.vision.lr_multiplier,
"lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
},
]
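With the defaults above (optimizer_lr=2e-5, vision_encoder_lr_multiplier=0.1), the vision encoder
parameter group therefore trains at 2e-5 * 0.1 = 2e-6, while all remaining parameters use 2e-5.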
@@ -118,8 +117,8 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
if self.config.observation_encoder.text:
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
# Always include task queue for text conditioning
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training or validation."""
@@ -14,14 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains a base objective class, and implementation of the objective
classes for use in the Multi-Task Diffusion Transformer Policy.
"""Objective implementations for Multi-Task DiT policy.
Architecture:
- BaseObjective: Abstract interface definition
- DiffusionObjective: Implements standard DDPM/DDIM diffusion objective
- FlowMatchingObjective: Implements flow matching objective
- DiffusionObjective: Standard DDPM/DDIM diffusion
- FlowMatchingObjective: Flow matching with ODE integration
"""
from abc import ABC, abstractmethod
@@ -35,10 +31,7 @@ from torch import Tensor
class BaseObjective(ABC):
"""
Base class for objectives used in Multi-Task DiT policy.
Defines the interface for training loss computation and conditional sampling.
"""
"""Base class for objectives used in Multi-Task DiT policy."""
def __init__(self, config, action_dim: int, horizon: int):
self.config = config
@@ -47,38 +40,17 @@ class BaseObjective(ABC):
@abstractmethod
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
"""Compute training loss for the objective.
Args:
model: The prediction network
batch: Training batch with observations and actions
conditioning_vec: Encoded observation features for conditioning
Returns:
Scalar loss tensor
"""
"""Compute training loss."""
pass
@abstractmethod
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
"""Generate action samples conditioned on embedded observation features.
Args:
model: The prediction network
batch_size: The number of samples to generate
conditioning_vec: Encoded observation features for conditioning
Returns:
Generated action sequences (batch_size, horizon, action_dim)
"""
"""Generate action samples conditioned on observations."""
pass
class DiffusionObjective(BaseObjective):
"""Standard diffusion (DDPM/DDIM) objective implementation.
Contains the noise scheduler, training loss, and conditional sampling.
"""
"""Standard diffusion (DDPM/DDIM) objective implementation."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
@@ -90,8 +62,8 @@ class DiffusionObjective(BaseObjective):
"beta_start": config.beta_start,
"beta_end": config.beta_end,
"beta_schedule": config.beta_schedule,
"clip_sample": getattr(config, "clip_sample", True),
"clip_sample_range": getattr(config, "clip_sample_range", 1.0),
"clip_sample": config.clip_sample,
"clip_sample_range": config.clip_sample_range,
"prediction_type": config.prediction_type,
}
@@ -105,7 +77,7 @@ class DiffusionObjective(BaseObjective):
# Inference steps default to training steps if not provided
self.num_inference_steps = (
config.num_inference_steps
if getattr(config, "num_inference_steps", None) is not None
if config.num_inference_steps is not None
else self.noise_scheduler.config.num_train_timesteps
)
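For reference, the flattened diffusion fields map directly onto a noise scheduler; a minimal
sketch, assuming the objective instantiates a diffusers DDPM/DDIM scheduler from the kwargs
built above (config is a MultiTaskDiTConfig instance):

    from diffusers import DDIMScheduler, DDPMScheduler

    scheduler_cls = DDPMScheduler if config.noise_scheduler_type == "DDPM" else DDIMScheduler
    noise_scheduler = scheduler_cls(
        num_train_timesteps=config.num_train_timesteps,
        beta_start=config.beta_start,
        beta_end=config.beta_end,
        beta_schedule=config.beta_schedule,
        clip_sample=config.clip_sample,
        clip_sample_range=config.clip_sample_range,
        prediction_type=config.prediction_type,
    )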
@@ -161,40 +133,29 @@ class DiffusionObjective(BaseObjective):
class FlowMatchingObjective(BaseObjective):
"""
Flow matching objective: trains a model to predict velocity fields v_θ(x, t) that transport
noise to data. It interpolates a path between noise and the target distribution
with a fixed target velocity.
"""
"""Flow matching objective: trains a model to predict velocity fields."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
"""Sample timesteps according to configured strategy.
Uniform: Sample t uniformly from [0,1]
Beta: Sample t from Beta(α,β) scaled to [0,s], emphasizing high noise (low t)
"""
if self.config.timestep_sampling.strategy_name == "uniform":
"""Sample timesteps according to configured strategy."""
if self.config.timestep_sampling_strategy == "uniform":
return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling.strategy_name == "beta":
elif self.config.timestep_sampling_strategy == "beta":
# Sample u ~ Beta(α, β) then transform: t = s(1-u)
# This emphasizes t near 0 (high noise) when α > β
beta_dist = torch.distributions.Beta(
self.config.timestep_sampling.alpha, self.config.timestep_sampling.beta
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
)
u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling.s * (1.0 - u)
return self.config.timestep_sampling_s * (1.0 - u)
else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling.strategy_name}")
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
"""Compute flow matching training loss.
Trains the model to predict the velocity field along linear interpolation paths.
"""
"""Compute flow matching training loss."""
data = batch["action"] # Clean action sequences (B, T, D)
batch_size = data.shape[0]
device = data.device
@@ -217,10 +178,7 @@ class FlowMatchingObjective(BaseObjective):
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
"""Generate actions by integrating the learned velocity field via ODE.
Solves: dx/dt = v_θ(x,t) from t=0 (noise) to t=1 (data)
"""
"""Generate actions by integrating the learned velocity field via ODE."""
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
@@ -244,9 +202,7 @@ class FlowMatchingObjective(BaseObjective):
def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
"""
Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)
"""
"""Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)"""
x = x_init
for i in range(len(time_grid) - 1):
@@ -268,23 +224,10 @@ class FlowMatchingObjective(BaseObjective):
def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
"""4th-order Runge-Kutta integration.
Uses 4 velocity evaluations per step:
k1 = v(x, t)
k2 = v(x + dt·k1/2, t + dt/2)
k3 = v(x + dt·k2/2, t + dt/2)
k4 = v(x + dt·k3, t + dt)
x_next = x + dt/6·(k1 + 2k2 + 2k3 + k4)
4x slower than Euler but more accurate.
Note: in practice, the extra accuracy does not seem to matter much.
"""
"""4th-order Runge-Kutta integration."""
x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
"""dynamics helper to get velocity at (x, t)"""
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
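The beta timestep-sampling fields can be exercised on their own; a minimal standalone sketch of
the transform t = s * (1 - u) with u ~ Beta(alpha, beta), which concentrates t near 0 (high noise)
when alpha > beta:

    import torch

    def sample_beta_timesteps(batch_size, s=0.999, alpha=1.5, beta=1.0):
        # defaults match timestep_sampling_s / _alpha / _beta in the flattened config
        u = torch.distributions.Beta(alpha, beta).sample((batch_size,))
        return s * (1.0 - u)

    t = sample_beta_timesteps(10_000)
    print(t.mean())  # ~0.4 with these defaults, vs ~0.5 for uniform sampling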
@@ -68,9 +68,7 @@ class CLIPVisionEncoder(nn.Module):
class CLIPTextEncoder(nn.Module):
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
"""
"""CLIP text encoder with frozen weights and a learnable projection layer."""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
@@ -126,21 +124,20 @@ class ObservationEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
vision_config = config.observation_encoder.vision
self._setup_preprocessing(vision_config)
self._setup_preprocessing(config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys()) # Preserve ordering
self.camera_names = list(config.image_features.keys())
if vision_config.use_separate_encoder_per_camera:
if config.use_separate_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None
else:
self.vision_encoder = None
@@ -158,9 +155,8 @@ class ObservationEncoder(nn.Module):
else:
self.env_state_dim = 0
text_config = config.observation_encoder.text
self.text_dim = config.transformer.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self._setup_vector_output()
@@ -173,22 +169,23 @@ class ObservationEncoder(nn.Module):
return images
def _setup_preprocessing(self, vision_config):
def _setup_preprocessing(self, config):
"""Setup image preprocessing transforms."""
if vision_config.resize_shape is not None:
if config.image_resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=vision_config.resize_shape,
size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if vision_config.crop_shape is not None:
if config.image_crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
if vision_config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -199,7 +196,7 @@ class ObservationEncoder(nn.Module):
# Vision features - get CLS token feature dimension
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders.values()))
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
# Get output shape from encoder (deterministic for CLS tokens)
feature_map_shape = encoder_to_check.get_output_shape()
@@ -233,8 +230,7 @@ class ObservationEncoder(nn.Module):
# Shape is (B, N, C, H, W) - add time dimension
images = images.unsqueeze(1) # (B, 1, N, C, H, W)
vision_config = self.config.observation_encoder.vision
if vision_config.use_separate_encoder_per_camera:
if self.config.use_separate_encoder_per_camera:
# Process each camera with its own encoder
camera_features = []
@@ -308,32 +308,29 @@ class TransformerBlock(nn.Module):
class DiffusionTransformer(nn.Module):
"""
Transformer-based diffusion noise prediction model.
"""
"""Transformer-based diffusion noise prediction model."""
def __init__(self, config, conditioning_dim: int):
"""Initialize transformer for noise prediction.
Args:
config: Multi-Task DiTConfig with transformer parameters
config: MultiTaskDiTConfig with transformer parameters
conditioning_dim: Dimension of concatenated observation features
"""
super().__init__()
self.config = config
self.transformer_config = config.transformer
self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon
self.hidden_size = self.transformer_config.hidden_dim
self.num_layers = self.transformer_config.num_layers
self.num_heads = self.transformer_config.num_heads
self.dropout = self.transformer_config.dropout
self.use_rope = self.transformer_config.use_rope
self.hidden_size = config.hidden_dim
self.num_layers = config.num_layers
self.num_heads = config.num_heads
self.dropout = config.dropout
self.use_rope = config.use_rope
self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
self.timestep_embed_dim = config.timestep_embed_dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
@@ -347,7 +344,7 @@ class DiffusionTransformer(nn.Module):
# Project action dimensions to hidden size
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if self.transformer_config.use_positional_encoding:
if config.use_positional_encoding:
# Learnable positional embeddings for sequence positions (absolute encoding)
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
@@ -363,8 +360,8 @@ class DiffusionTransformer(nn.Module):
num_features=self.cond_dim,
dropout=self.dropout,
use_rope=self.use_rope,
max_seq_len=self.horizon, # This remains fixed because we aren't generating variable length sequences
rope_base=getattr(self.transformer_config, "rope_base", 10000.0),
max_seq_len=self.horizon,
rope_base=config.rope_base,
)
for _ in range(self.num_layers)
]
@@ -377,9 +374,7 @@ class DiffusionTransformer(nn.Module):
self._initialize_weights()
def _initialize_weights(self):
"""
Zero-initializing the final linear layer of adaLN_modulation in each block improves training stability
"""
"""Zero-initialize final linear layer of adaLN_modulation for training stability."""
for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
@@ -28,11 +28,7 @@ import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
DiffusionConfig,
FlowMatchingConfig,
MultiTaskDiTConfig,
)
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
@@ -108,13 +104,12 @@ def create_config(
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
# Use smaller model for faster tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
# Use smaller model for faster tests
config.transformer.hidden_dim = 128
config.transformer.num_layers = 2
config.transformer.num_heads = 4
config.validate_features()
return config
@@ -189,18 +184,28 @@ def test_multi_task_dit_policy_diffusion_objective():
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
config.objective = DiffusionConfig(
# Use diffusion objective
objective="diffusion",
noise_scheduler_type="DDPM",
num_train_timesteps=100,
num_inference_steps=10,
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.train()
@@ -235,18 +240,28 @@ def test_multi_task_dit_policy_flow_matching_objective():
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
config.objective = FlowMatchingConfig(
# Use flow matching objective
objective="flow_matching",
sigma_min=0.0,
num_integration_steps=10, # Use fewer steps for faster tests
num_integration_steps=10, # Fewer steps for faster tests
integration_method="euler",
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.train()
@@ -373,5 +388,5 @@ def test_multi_task_dit_policy_get_optim_params():
# Second group is vision encoder params with different lr
assert "params" in param_groups[1]
assert "lr" in param_groups[1]
expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier
expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
assert param_groups[1]["lr"] == expected_lr