simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives

This commit is contained in:
Bryson Jones
2025-12-10 11:45:59 -08:00
parent cdacc090cd
commit 103230c64c
7 changed files with 242 additions and 454 deletions
@@ -0,0 +1 @@
#!/usr/bin/env python
@@ -15,8 +15,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal
import draccus
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode from lerobot.configs.types import NormalizationMode
@@ -24,295 +23,84 @@ from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig from lerobot.optim.schedulers import DiffuserSchedulerConfig
@dataclass
class ObjectiveConfig(draccus.ChoiceRegistry):
"""Base configuration for model objectives (diffusion, flow matching, etc.)."""
pass
@ObjectiveConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig(ObjectiveConfig):
"""Configuration for standard diffusion model training and inference.
These parameters control the noise scheduling and denoising process for
standard DDPM/DDIM diffusion models.
"""
objective_name: str = field(default="diffusion", init=False)
# Noise scheduler configuration - controls diffusion process
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # 100 noise levels for fine-grained control
beta_schedule: str = "squaredcos_cap_v2" # Cosine schedule prevents extreme noise
beta_start: float = 0.0001 # Small initial noise level
beta_end: float = 0.02 # Moderate final noise level
prediction_type: str = "epsilon" # Predict noise (works better than direct prediction)
clip_sample: bool = True # Prevent extreme action values
clip_sample_range: float = 1.0 # Clip to [-1, 1] range
# Inference configuration
num_inference_steps: int | None = None # Default to num_train_timesteps
def __post_init__(self):
"""Validate diffusion-specific parameters."""
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(
"beta values must satisfy 0 <= beta_start <= beta_end <= 1, "
f"got {self.beta_start}, {self.beta_end}"
)
@dataclass
class TimestepSamplingConfig(draccus.ChoiceRegistry):
"""Base configuration for timestep sampling strategies during training."""
pass
@TimestepSamplingConfig.register_subclass("uniform")
@dataclass
class UniformTimestepSamplingConfig(TimestepSamplingConfig):
"""Uniform timestep sampling from [0, 1]."""
strategy_name: str = field(default="uniform", init=False)
@TimestepSamplingConfig.register_subclass("beta")
@dataclass
class BetaTimestepSamplingConfig(TimestepSamplingConfig):
"""Beta distribution timestep sampling.
Samples from Beta distribution emphasizing low timesteps (high noise).
This was inspired on the work from Physical Intelligence PI-0 model,
where they suggested the beta distribution for sampling timesteps
during training improved sample quality.
"""
strategy_name: str = field(default="beta", init=False)
s: float = 0.999 # Max timestep threshold for beta sampling
alpha: float = 1.5 # Beta distribution alpha parameter
beta: float = 1.0 # Beta distribution beta parameter
def __post_init__(self):
if not (0.0 < self.s <= 1.0):
raise ValueError(f"s must be in (0, 1], got {self.s}")
if self.alpha <= 0:
raise ValueError(f"alpha must be positive, got {self.alpha}")
if self.beta <= 0:
raise ValueError(f"beta must be positive, got {self.beta}")
@ObjectiveConfig.register_subclass("flow_matching")
@dataclass
class FlowMatchingConfig(ObjectiveConfig):
"""Configuration for flow matching training and inference.
These parameters control the velocity field learning and ODE integration
process for flow matching models.
"""
objective_name: str = field(default="flow_matching", init=False)
# Flow path construction
sigma_min: float = 0.0 # Minimum noise level in flow interpolation path
# ODE integration for inference
num_integration_steps: int = (
100 # Number of ODE integration steps (increased from 50 for smoother trajectories)
)
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
# Timestep sampling strategy for training
# Beta distribution found to be the most effective in practice, so it is the default
timestep_sampling: TimestepSamplingConfig = field(default_factory=BetaTimestepSamplingConfig)
def __post_init__(self):
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(f"integration_method must be 'euler' or 'rk4', got {self.integration_method}")
@dataclass
class TransformerConfig:
"""Configuration for Transformer-based prediction model.
These parameters control the transformer architecture used for noise/velocity
prediction in diffusion and flow matching models.
"""
# Transformer architecture parameters
hidden_dim: int = 512 # Hidden dimension of transformer
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = False # Whether to use absolute positional encoding
diffusion_step_embed_dim: int = 256 # Timestep embedding size
# RoPE (Rotary Position Embedding) configuration
use_rope: bool = True # Whether to use Rotary Position Embedding in attention (baseline is True)
rope_base: float = 10000.0 # Base frequency for RoPE computation
def __post_init__(self):
"""Validate Transformer-specific parameters."""
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
if self.diffusion_step_embed_dim <= 0:
raise ValueError("diffusion_step_embed_dim must be positive")
@dataclass
class VisionEncoderConfig:
"""Configuration for CLIP vision encoder.
Uses CLIPVisionModel from transformers library.
CLS token usage is handled automatically.
CLIP's internal preprocessing (resize to 224x224) can be overridden
by setting resize_shape and crop_shape.
All image preprocessing is centralized here:
1. Resize (optional) - resize images to target resolution
2. Crop (optional) - crop after resize, must be smaller than resize_shape
3. Random crop - whether to use random cropping during training
Any CLIP model from transformers can be used. Examples:
- openai/clip-vit-base-patch16 (default, 768 dims)
- openai/clip-vit-large-patch14 (1024 dims)
- laion/CLIP-ViT-B-32-xlaai256 (alternative CLIP model)
"""
model_name: str = "openai/clip-vit-base-patch16"
use_separate_encoder_per_camera: bool = False
# Learning rate multiplier for vision encoder parameters
# Vision encoder learning rate = optimizer_lr * lr_multiplier
lr_multiplier: float = 0.1
# Image preprocessing (centralized)
resize_shape: tuple[int, int] | None = None
crop_shape: tuple[int, int] | None = (224, 224) # default input size for CLIP
crop_is_random: bool = True
def __post_init__(self):
# Validate that model name contains "clip" to ensure correct encoder type
if "clip" not in self.model_name.lower():
raise ValueError(
f"model_name must be a CLIP model from transformers (contain 'clip'), got '{self.model_name}'"
)
if (
self.resize_shape
and self.crop_shape
and (self.crop_shape[0] > self.resize_shape[0] or self.crop_shape[1] > self.resize_shape[1])
):
raise ValueError(
f"crop_shape {self.crop_shape} must be smaller than or equal to "
f"resize_shape {self.resize_shape}. Got crop={self.crop_shape}, resize={self.resize_shape}"
)
@dataclass
class TextEncoderConfig:
"""Configuration for CLIP text encoder.
Uses CLIP's text encoder to embed task descriptions, which are then
used to condition the policy. The text embeddings are processed by
a learnable projection layer before being concatenated into the
conditioning vector.
Any HuggingFace CLIP model can be used. Examples:
- openai/clip-vit-base-patch16 (default)
- openai/clip-vit-large-patch14
"""
model: str = "openai/clip-vit-base-patch16"
def __post_init__(self):
# Validate that model name contains "clip" to ensure correct encoder type
if "clip" not in self.model.lower():
raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
@dataclass
class ObservationEncoderConfig:
"""Top-level configuration for observation encoding.
This config combines:
- Vision encoding (required): CLIP vision encoder from transformers
"""
vision: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
@PreTrainedConfig.register_subclass("multi_task_dit") @PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass @dataclass
class MultiTaskDiTConfig(PreTrainedConfig): class MultiTaskDiTConfig(PreTrainedConfig):
""" """Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
Configuration class for the Multi-Task Diffusion Transformer (DiT) policy.
A transformer-based policy that supports both diffusion and flow matching objectives
for multi-task robot learning with text and vision conditioning.
""" """
# Temporal structure - controls how the policy processes time and predicts actions n_obs_steps: int = 2 # Number of observation steps for temporal context
n_obs_steps: int = 2 # num observations for temporal context (..., t-1, t) horizon: int = 32 # Number of action steps to predict
horizon: int = 100 # predicted action steps into the future n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
n_action_steps: int = 24 # actions per policy call (receding horizon) -- ~0.8s is a good place to start
# Normalization strategy - critical for diffusion model performance # Objective Selection
objective: Literal["diffusion", "flow_matching"] = "diffusion"
# --- Diffusion-specific (used when objective="diffusion") ---
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # Number of diffusion timesteps
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
beta_start: float = 0.0001 # Starting noise level
beta_end: float = 0.02 # Ending noise level
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
clip_sample: bool = True # Clip samples during denoising
clip_sample_range: float = 1.0 # Clipping range [-x, x]
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
# --- Flow Matching-specific (used when objective="flow_matching") ---
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
num_integration_steps: int = 100 # ODE integration steps at inference
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
timestep_sampling_strategy: Literal["uniform", "beta"] = "beta"
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
# Transformer Architecture
hidden_dim: int = 512 # Transformer hidden dimension
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = False # Use absolute positional encoding
timestep_embed_dim: int = 256 # Timestep embedding dimension
use_rope: bool = True # Use Rotary Position Embedding
rope_base: float = 10000.0 # RoPE base frequency
# Vision Encoder (CLIP)
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
use_separate_encoder_per_camera: bool = False # Separate encoder per camera view
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
image_crop_is_random: bool = True # Random crop during training, center at inference
# Text Encoder (CLIP)
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD, # Standard ImageNet normalization for vision "VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX, # [-1,1] range for proper diffusion clipping "STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX, # [-1,1] range required for diffusion process "ACTION": NormalizationMode.MIN_MAX,
} }
) )
drop_n_last_frames: int | None = None # Auto-calculated: horizon - n_action_steps - n_obs_steps + 1 # Training/Optimizer
observation_encoder: ObservationEncoderConfig = field(default_factory=ObservationEncoderConfig)
transformer: TransformerConfig = field(default_factory=TransformerConfig)
objective: ObjectiveConfig = field(default_factory=DiffusionConfig)
do_mask_loss_for_padding: bool = False # same logic as is implemented in LeRobot DP implementation
# training optimizer and scheduler hyperparameters
optimizer_lr: float = 2e-5 optimizer_lr: float = 2e-5
optimizer_betas: tuple = (0.95, 0.999) optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8 optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0 # No weight decay is suggested to be optimal optimizer_weight_decay: float = 0.0
scheduler_name: str = "cosine" scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 0 # No warmup found to be optimal scheduler_warmup_steps: int = 0
do_mask_loss_for_padding: bool = False
# Auto-calculated
drop_n_last_frames: int | None = None
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
@@ -320,11 +108,78 @@ class MultiTaskDiTConfig(PreTrainedConfig):
if self.drop_n_last_frames is None: if self.drop_n_last_frames is None:
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1 self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
def get_optimizer_preset(self) -> AdamConfig: self._validate()
"""Return Adam optimizer configuration optimized for diffusion training.
Note: Vision encoder learning rate is set separately via get_optim_params. def _validate(self):
""" """Validate configuration parameters."""
# Transformer validation
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
# Vision encoder validation
if "clip" not in self.vision_encoder_name.lower():
raise ValueError(
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
)
if (
self.image_resize_shape
and self.image_crop_shape
and (
self.image_crop_shape[0] > self.image_resize_shape[0]
or self.image_crop_shape[1] > self.image_resize_shape[1]
)
):
raise ValueError(
f"image_crop_shape {self.image_crop_shape} must be <= image_resize_shape {self.image_resize_shape}"
)
# Text encoder validation
if "clip" not in self.text_encoder_name.lower():
raise ValueError(
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
)
# Objective-specific validation
if self.objective == "diffusion":
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
elif self.objective == "flow_matching":
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
)
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
if self.timestep_sampling_strategy == "beta":
if not (0.0 < self.timestep_sampling_s <= 1.0):
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
if self.timestep_sampling_alpha <= 0:
raise ValueError("timestep_sampling_alpha must be positive")
if self.timestep_sampling_beta <= 0:
raise ValueError("timestep_sampling_beta must be positive")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig( return AdamConfig(
lr=self.optimizer_lr, lr=self.optimizer_lr,
betas=self.optimizer_betas, betas=self.optimizer_betas,
@@ -333,7 +188,6 @@ class MultiTaskDiTConfig(PreTrainedConfig):
) )
def get_scheduler_preset(self) -> DiffuserSchedulerConfig: def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
"""Return learning rate scheduler configuration."""
return DiffuserSchedulerConfig( return DiffuserSchedulerConfig(
name=self.scheduler_name, name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps, num_warmup_steps=self.scheduler_warmup_steps,
@@ -341,56 +195,41 @@ class MultiTaskDiTConfig(PreTrainedConfig):
def validate_features(self) -> None: def validate_features(self) -> None:
"""Validate that required input features are present and properly configured.""" """Validate that required input features are present and properly configured."""
# Robot state is always present via self.robot_state_feature, so we don't need to enforce images/env_state if self.image_crop_shape is not None:
# This allows for testing and simple state-only policies
# Validate crop shape fits within image dimensions
crop_shape = self.observation_encoder.vision.crop_shape
if crop_shape is not None:
for key, image_ft in self.image_features.items(): for key, image_ft in self.image_features.items():
if crop_shape[0] > image_ft.shape[1] or crop_shape[1] > image_ft.shape[2]: if (
self.image_crop_shape[0] > image_ft.shape[1]
or self.image_crop_shape[1] > image_ft.shape[2]
):
raise ValueError( raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {crop_shape} " f"image_crop_shape {self.image_crop_shape} doesn't fit within image shape {image_ft.shape} "
f"for `crop_shape` and {image_ft.shape} for " f"for '{key}'"
f"`{key}`."
) )
# Ensure all images have same shape (current limitation)
if len(self.image_features) > 0: if len(self.image_features) > 0:
first_image_key, first_image_ft = next(iter(self.image_features.items())) first_key, first_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items(): for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape: if image_ft.shape != first_ft.shape:
raise ValueError( raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
) )
@property
def model_objective(self) -> str:
return self.objective.objective_name
@property @property
def is_diffusion(self) -> bool: def is_diffusion(self) -> bool:
return isinstance(self.objective, DiffusionConfig) return self.objective == "diffusion"
@property @property
def is_flow_matching(self) -> bool: def is_flow_matching(self) -> bool:
return isinstance(self.objective, FlowMatchingConfig) return self.objective == "flow_matching"
def get_objective_config(self) -> DiffusionConfig | FlowMatchingConfig:
"""Get the objective-specific configuration with proper typing."""
return self.objective
@property @property
def observation_delta_indices(self) -> list: def observation_delta_indices(self) -> list:
"""Delta indices for stacking observations. Provides temporal context."""
return list(range(1 - self.n_obs_steps, 1)) return list(range(1 - self.n_obs_steps, 1))
@property @property
def action_delta_indices(self) -> list: def action_delta_indices(self) -> list:
"""Delta indices for action horizon prediction."""
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon)) return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property @property
def reward_delta_indices(self) -> None: def reward_delta_indices(self) -> None:
"""Indices for reward deltas (not used in diffusion policy)."""
return None return None
@@ -52,23 +52,22 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
action_dim = config.action_feature.shape[0] action_dim = config.action_feature.shape[0]
horizon = config.horizon horizon = config.horizon
self.model_objective = config.model_objective
if config.is_diffusion: if config.is_diffusion:
self.objective = DiffusionObjective( self.objective = DiffusionObjective(
config.get_objective_config(), config,
action_dim=action_dim, action_dim=action_dim,
horizon=horizon, horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding, do_mask_loss_for_padding=config.do_mask_loss_for_padding,
) )
elif config.is_flow_matching: elif config.is_flow_matching:
self.objective = FlowMatchingObjective( self.objective = FlowMatchingObjective(
config.get_objective_config(), config,
action_dim=action_dim, action_dim=action_dim,
horizon=horizon, horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding, do_mask_loss_for_padding=config.do_mask_loss_for_padding,
) )
else: else:
raise ValueError(f"Unsupported model_objective: {self.model_objective}") raise ValueError(f"Unsupported objective: {config.objective}")
self.reset() self.reset()
@@ -90,7 +89,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
{"params": non_vision_params}, {"params": non_vision_params},
{ {
"params": vision_encoder_params, "params": vision_encoder_params,
"lr": self.config.optimizer_lr * self.config.observation_encoder.vision.lr_multiplier, "lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
}, },
] ]
@@ -118,8 +117,8 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
if self.config.env_state_feature: if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
if self.config.observation_encoder.text: # Always include task queue for text conditioning
self._queues["task"] = deque(maxlen=self.config.n_obs_steps) self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training or validation.""" """Run the batch through the model and compute the loss for training or validation."""
@@ -14,14 +14,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """Objective implementations for Multi-Task DiT policy.
This module contains a base objective class, and implementation of the objective
classes for use in the Multi-Task Diffusion Transformer Policy.
Architecture: - DiffusionObjective: Standard DDPM/DDIM diffusion
- BaseObjective: Abstract interface definition - FlowMatchingObjective: Flow matching with ODE integration
- DiffusionObjective: Implements standard DDPM/DDIM diffusion objective
- FlowMatchingObjective: Implements flow matching objective
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -35,10 +31,7 @@ from torch import Tensor
class BaseObjective(ABC): class BaseObjective(ABC):
""" """Base class for objectives used in Multi-Task DiT policy."""
Base class for objectives used in Multi-Task DiT policy.
Defines the interface for training loss computation and conditional sampling.
"""
def __init__(self, config, action_dim: int, horizon: int): def __init__(self, config, action_dim: int, horizon: int):
self.config = config self.config = config
@@ -47,38 +40,17 @@ class BaseObjective(ABC):
@abstractmethod @abstractmethod
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
"""Compute training loss for the objective. """Compute training loss."""
Args:
model: The prediction network
batch: Training batch with observations and actions
conditioning_vec: Encoded observation features for conditioning
Returns:
Scalar loss tensor
"""
pass pass
@abstractmethod @abstractmethod
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
"""Generate action samples conditioned on embedded observation features. """Generate action samples conditioned on observations."""
Args:
model: The prediction network
batch_size: The number of samples to generate
conditioning_vec: Encoded observation features for conditioning
Returns:
Generated action sequences (batch_size, horizon, action_dim)
"""
pass pass
class DiffusionObjective(BaseObjective): class DiffusionObjective(BaseObjective):
"""Standard diffusion (DDPM/DDIM) objective implementation. """Standard diffusion (DDPM/DDIM) objective implementation."""
Contains the noise scheduler, training loss, and conditional sampling.
"""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon) super().__init__(config, action_dim, horizon)
@@ -90,8 +62,8 @@ class DiffusionObjective(BaseObjective):
"beta_start": config.beta_start, "beta_start": config.beta_start,
"beta_end": config.beta_end, "beta_end": config.beta_end,
"beta_schedule": config.beta_schedule, "beta_schedule": config.beta_schedule,
"clip_sample": getattr(config, "clip_sample", True), "clip_sample": config.clip_sample,
"clip_sample_range": getattr(config, "clip_sample_range", 1.0), "clip_sample_range": config.clip_sample_range,
"prediction_type": config.prediction_type, "prediction_type": config.prediction_type,
} }
@@ -105,7 +77,7 @@ class DiffusionObjective(BaseObjective):
# Inference steps default to training steps if not provided # Inference steps default to training steps if not provided
self.num_inference_steps = ( self.num_inference_steps = (
config.num_inference_steps config.num_inference_steps
if getattr(config, "num_inference_steps", None) is not None if config.num_inference_steps is not None
else self.noise_scheduler.config.num_train_timesteps else self.noise_scheduler.config.num_train_timesteps
) )
@@ -161,40 +133,29 @@ class DiffusionObjective(BaseObjective):
class FlowMatchingObjective(BaseObjective): class FlowMatchingObjective(BaseObjective):
""" """Flow matching objective: trains a model to predict velocity fields."""
Flow matching objective: trains a model to predict velocity fields v_θ(x, t) that transports
noise to data. This basically interpolates as path between noise and a trained distribution
with a set target velocity.
"""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon) super().__init__(config, action_dim, horizon)
self.do_mask_loss_for_padding = do_mask_loss_for_padding self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor: def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
"""Sample timesteps according to configured strategy. """Sample timesteps according to configured strategy."""
if self.config.timestep_sampling_strategy == "uniform":
Uniform: Sample t uniformly from [0,1]
Beta: Sample t from Beta(α,β) scaled to [0,s], emphasizing high noise (low t)
"""
if self.config.timestep_sampling.strategy_name == "uniform":
return torch.rand(batch_size, device=device) return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling.strategy_name == "beta": elif self.config.timestep_sampling_strategy == "beta":
# Sample u ~ Beta(α, β) then transform: t = s(1-u) # Sample u ~ Beta(α, β) then transform: t = s(1-u)
# This emphasizes t near 0 (high noise) when α > β # This emphasizes t near 0 (high noise) when α > β
beta_dist = torch.distributions.Beta( beta_dist = torch.distributions.Beta(
self.config.timestep_sampling.alpha, self.config.timestep_sampling.beta self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
) )
u = beta_dist.sample((batch_size,)).to(device) u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling.s * (1.0 - u) return self.config.timestep_sampling_s * (1.0 - u)
else: else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling.strategy_name}") raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
"""Compute flow matching training loss. """Compute flow matching training loss."""
Trains the model to predict the velocity field along linear interpolation paths.
"""
data = batch["action"] # Clean action sequences (B, T, D) data = batch["action"] # Clean action sequences (B, T, D)
batch_size = data.shape[0] batch_size = data.shape[0]
device = data.device device = data.device
@@ -217,10 +178,7 @@ class FlowMatchingObjective(BaseObjective):
return loss.mean() return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
"""Generate actions by integrating the learned velocity field via ODE. """Generate actions by integrating the learned velocity field via ODE."""
Solves: dx/dt = v_θ(x,t) from t=0 (noise) to t=1 (data)
"""
device = next(model.parameters()).device device = next(model.parameters()).device
dtype = next(model.parameters()).dtype dtype = next(model.parameters()).dtype
@@ -244,9 +202,7 @@ class FlowMatchingObjective(BaseObjective):
def _euler_integrate( def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor: ) -> Tensor:
""" """Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)"""
Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)
"""
x = x_init x = x_init
for i in range(len(time_grid) - 1): for i in range(len(time_grid) - 1):
@@ -268,23 +224,10 @@ class FlowMatchingObjective(BaseObjective):
def _rk4_integrate( def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor: ) -> Tensor:
"""4th-order Runge-Kutta integration. """4th-order Runge-Kutta integration."""
Uses 4 velocity evaluations per step:
k1 = v(x, t)
k2 = v(x + dt·k1/2, t + dt/2)
k3 = v(x + dt·k2/2, t + dt/2)
k4 = v(x + dt·k3, t + dt)
x_next = x + dt/6·(k1 + 2k2 + 2k3 + k4)
4x slower than Euler but more accurate.
Note: In practice, this seems to not matter much
"""
x = x_init x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor: def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
"""dynamics helper to get velocity at (x, t)"""
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device) t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad(): with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec) return model(x_val, t_batch, conditioning_vec=conditioning_vec)
@@ -68,9 +68,7 @@ class CLIPVisionEncoder(nn.Module):
class CLIPTextEncoder(nn.Module): class CLIPTextEncoder(nn.Module):
"""Supports any HuggingFace CLIP model. The encoder weights are frozen, """CLIP text encoder with frozen weights and a learnable projection layer."""
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512): def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__() super().__init__()
@@ -126,21 +124,20 @@ class ObservationEncoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
vision_config = config.observation_encoder.vision
self._setup_preprocessing(vision_config) self._setup_preprocessing(config)
if config.image_features: if config.image_features:
self.num_cameras = len(config.image_features) self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys()) # Preserve ordering self.camera_names = list(config.image_features.keys())
if vision_config.use_separate_encoder_per_camera: if config.use_separate_encoder_per_camera:
self.vision_encoders = nn.ModuleList( self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names] [CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
) )
self.vision_encoder = None self.vision_encoder = None
else: else:
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name) self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None self.vision_encoders = None
else: else:
self.vision_encoder = None self.vision_encoder = None
@@ -158,9 +155,8 @@ class ObservationEncoder(nn.Module):
else: else:
self.env_state_dim = 0 self.env_state_dim = 0
text_config = config.observation_encoder.text self.text_dim = config.hidden_dim
self.text_dim = config.transformer.hidden_dim self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
self._setup_vector_output() self._setup_vector_output()
@@ -173,22 +169,23 @@ class ObservationEncoder(nn.Module):
return images return images
def _setup_preprocessing(self, vision_config): def _setup_preprocessing(self, config):
"""Setup image preprocessing transforms.""" """Setup image preprocessing transforms."""
if vision_config.resize_shape is not None: if config.image_resize_shape is not None:
self.do_resize = True self.do_resize = True
self.resize = torchvision.transforms.Resize( self.resize = torchvision.transforms.Resize(
size=vision_config.resize_shape, size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR, interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True, antialias=True,
) )
else: else:
self.do_resize = False self.do_resize = False
if vision_config.crop_shape is not None:
if config.image_crop_shape is not None:
self.do_crop = True self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape) self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if vision_config.crop_is_random: if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape) self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else: else:
self.maybe_random_crop = self.center_crop self.maybe_random_crop = self.center_crop
else: else:
@@ -199,7 +196,7 @@ class ObservationEncoder(nn.Module):
# Vision features - get CLS token feature dimension # Vision features - get CLS token feature dimension
if self.vision_encoder is not None or self.vision_encoders is not None: if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders.values())) encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
# Get output shape from encoder (deterministic for CLS tokens) # Get output shape from encoder (deterministic for CLS tokens)
feature_map_shape = encoder_to_check.get_output_shape() feature_map_shape = encoder_to_check.get_output_shape()
@@ -233,8 +230,7 @@ class ObservationEncoder(nn.Module):
# Shape is (B, N, C, H, W) - add time dimension # Shape is (B, N, C, H, W) - add time dimension
images = images.unsqueeze(1) # (B, 1, N, C, H, W) images = images.unsqueeze(1) # (B, 1, N, C, H, W)
vision_config = self.config.observation_encoder.vision if self.config.use_separate_encoder_per_camera:
if vision_config.use_separate_encoder_per_camera:
# Process each camera with its own encoder # Process each camera with its own encoder
camera_features = [] camera_features = []
@@ -308,32 +308,29 @@ class TransformerBlock(nn.Module):
class DiffusionTransformer(nn.Module): class DiffusionTransformer(nn.Module):
""" """Transformer-based diffusion noise prediction model."""
Transformer-based diffusion noise prediction model.
"""
def __init__(self, config, conditioning_dim: int): def __init__(self, config, conditioning_dim: int):
"""Initialize transformer for noise prediction. """Initialize transformer for noise prediction.
Args: Args:
config: Multi-Task DiTConfig with transformer parameters config: MultiTaskDiTConfig with transformer parameters
conditioning_dim: Dimension of concatenated observation features conditioning_dim: Dimension of concatenated observation features
""" """
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer_config = config.transformer
self.conditioning_dim = conditioning_dim self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0] self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon self.horizon = config.horizon
self.hidden_size = self.transformer_config.hidden_dim self.hidden_size = config.hidden_dim
self.num_layers = self.transformer_config.num_layers self.num_layers = config.num_layers
self.num_heads = self.transformer_config.num_heads self.num_heads = config.num_heads
self.dropout = self.transformer_config.dropout self.dropout = config.dropout
self.use_rope = self.transformer_config.use_rope self.use_rope = config.use_rope
self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim self.timestep_embed_dim = config.timestep_embed_dim
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim), SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim), nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
@@ -347,7 +344,7 @@ class DiffusionTransformer(nn.Module):
# Project action dimensions to hidden size # Project action dimensions to hidden size
self.input_proj = nn.Linear(self.action_dim, self.hidden_size) self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if self.transformer_config.use_positional_encoding: if config.use_positional_encoding:
# Learnable positional embeddings for sequence positions (absolute encoding) # Learnable positional embeddings for sequence positions (absolute encoding)
self.pos_embedding = nn.Parameter( self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02) torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
@@ -363,8 +360,8 @@ class DiffusionTransformer(nn.Module):
num_features=self.cond_dim, num_features=self.cond_dim,
dropout=self.dropout, dropout=self.dropout,
use_rope=self.use_rope, use_rope=self.use_rope,
max_seq_len=self.horizon, # This remains fixed because we aren't generating variable length sequences max_seq_len=self.horizon,
rope_base=getattr(self.transformer_config, "rope_base", 10000.0), rope_base=config.rope_base,
) )
for _ in range(self.num_layers) for _ in range(self.num_layers)
] ]
@@ -377,9 +374,7 @@ class DiffusionTransformer(nn.Module):
self._initialize_weights() self._initialize_weights()
def _initialize_weights(self): def _initialize_weights(self):
""" """Zero-initialize final linear layer of adaLN_modulation for training stability."""
Zero-initializing the final linear layer of adaLN_modulation in each block improves training stability
"""
for block in self.transformer_blocks: for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0) nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0) nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
@@ -28,11 +28,7 @@ import torch
from torch import Tensor from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import ( from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
DiffusionConfig,
FlowMatchingConfig,
MultiTaskDiTConfig,
)
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed from lerobot.utils.random_utils import seeded_context, set_seed
@@ -108,13 +104,12 @@ def create_config(
n_obs_steps=n_obs_steps, n_obs_steps=n_obs_steps,
horizon=horizon, horizon=horizon,
n_action_steps=n_action_steps, n_action_steps=n_action_steps,
# Use smaller model for faster tests
hidden_dim=128,
num_layers=2,
num_heads=4,
) )
# Use smaller model for faster tests
config.transformer.hidden_dim = 128
config.transformer.num_layers = 2
config.transformer.num_heads = 4
config.validate_features() config.validate_features()
return config return config
@@ -189,18 +184,28 @@ def test_multi_task_dit_policy_diffusion_objective():
horizon = 16 horizon = 16
n_action_steps = 8 n_action_steps = 8
config = create_config( input_features = {
state_dim=state_dim, OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
action_dim=action_dim, f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps, n_obs_steps=n_obs_steps,
horizon=horizon, horizon=horizon,
n_action_steps=n_action_steps, n_action_steps=n_action_steps,
) # Use diffusion objective
config.objective = DiffusionConfig( objective="diffusion",
noise_scheduler_type="DDPM", noise_scheduler_type="DDPM",
num_train_timesteps=100, num_train_timesteps=100,
num_inference_steps=10, num_inference_steps=10,
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
) )
config.validate_features()
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.train() policy.train()
@@ -235,18 +240,28 @@ def test_multi_task_dit_policy_flow_matching_objective():
horizon = 16 horizon = 16
n_action_steps = 8 n_action_steps = 8
config = create_config( input_features = {
state_dim=state_dim, OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
action_dim=action_dim, f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps, n_obs_steps=n_obs_steps,
horizon=horizon, horizon=horizon,
n_action_steps=n_action_steps, n_action_steps=n_action_steps,
) # Use flow matching objective
config.objective = FlowMatchingConfig( objective="flow_matching",
sigma_min=0.0, sigma_min=0.0,
num_integration_steps=10, # Use fewer steps for faster tests num_integration_steps=10, # Fewer steps for faster tests
integration_method="euler", integration_method="euler",
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
) )
config.validate_features()
policy = MultiTaskDiTPolicy(config=config) policy = MultiTaskDiTPolicy(config=config)
policy.train() policy.train()
@@ -373,5 +388,5 @@ def test_multi_task_dit_policy_get_optim_params():
# Second group is vision encoder params with different lr # Second group is vision encoder params with different lr
assert "params" in param_groups[1] assert "params" in param_groups[1]
assert "lr" in param_groups[1] assert "lr" in param_groups[1]
expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
assert param_groups[1]["lr"] == expected_lr assert param_groups[1]["lr"] == expected_lr