mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives
This commit is contained in:
@@ -0,0 +1 @@
|
||||
#!/usr/bin/env python
|
||||
@@ -15,8 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
from typing import Literal
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
@@ -24,295 +23,84 @@ from lerobot.optim.optimizers import AdamConfig
|
||||
from lerobot.optim.schedulers import DiffuserSchedulerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObjectiveConfig(draccus.ChoiceRegistry):
|
||||
"""Base configuration for model objectives (diffusion, flow matching, etc.)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ObjectiveConfig.register_subclass("diffusion")
|
||||
@dataclass
|
||||
class DiffusionConfig(ObjectiveConfig):
|
||||
"""Configuration for standard diffusion model training and inference.
|
||||
|
||||
These parameters control the noise scheduling and denoising process for
|
||||
standard DDPM/DDIM diffusion models.
|
||||
"""
|
||||
|
||||
objective_name: str = field(default="diffusion", init=False)
|
||||
|
||||
# Noise scheduler configuration - controls diffusion process
|
||||
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
|
||||
num_train_timesteps: int = 100 # 100 noise levels for fine-grained control
|
||||
beta_schedule: str = "squaredcos_cap_v2" # Cosine schedule prevents extreme noise
|
||||
beta_start: float = 0.0001 # Small initial noise level
|
||||
beta_end: float = 0.02 # Moderate final noise level
|
||||
prediction_type: str = "epsilon" # Predict noise (works better than direct prediction)
|
||||
clip_sample: bool = True # Prevent extreme action values
|
||||
clip_sample_range: float = 1.0 # Clip to [-1, 1] range
|
||||
|
||||
# Inference configuration
|
||||
num_inference_steps: int | None = None # Default to num_train_timesteps
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate diffusion-specific parameters."""
|
||||
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
|
||||
raise ValueError(
|
||||
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
|
||||
)
|
||||
|
||||
if self.prediction_type not in ["epsilon", "sample"]:
|
||||
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
|
||||
|
||||
if self.num_train_timesteps <= 0:
|
||||
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
|
||||
|
||||
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
|
||||
raise ValueError(
|
||||
"beta values must satisfy 0 <= beta_start <= beta_end <= 1, "
|
||||
f"got {self.beta_start}, {self.beta_end}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimestepSamplingConfig(draccus.ChoiceRegistry):
|
||||
"""Base configuration for timestep sampling strategies during training."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@TimestepSamplingConfig.register_subclass("uniform")
|
||||
@dataclass
|
||||
class UniformTimestepSamplingConfig(TimestepSamplingConfig):
|
||||
"""Uniform timestep sampling from [0, 1]."""
|
||||
|
||||
strategy_name: str = field(default="uniform", init=False)
|
||||
|
||||
|
||||
@TimestepSamplingConfig.register_subclass("beta")
|
||||
@dataclass
|
||||
class BetaTimestepSamplingConfig(TimestepSamplingConfig):
|
||||
"""Beta distribution timestep sampling.
|
||||
|
||||
Samples from Beta distribution emphasizing low timesteps (high noise).
|
||||
|
||||
This was inspired on the work from Physical Intelligence PI-0 model,
|
||||
where they suggested the beta distribution for sampling timesteps
|
||||
during training improved sample quality.
|
||||
"""
|
||||
|
||||
strategy_name: str = field(default="beta", init=False)
|
||||
|
||||
s: float = 0.999 # Max timestep threshold for beta sampling
|
||||
alpha: float = 1.5 # Beta distribution alpha parameter
|
||||
beta: float = 1.0 # Beta distribution beta parameter
|
||||
|
||||
def __post_init__(self):
|
||||
if not (0.0 < self.s <= 1.0):
|
||||
raise ValueError(f"s must be in (0, 1], got {self.s}")
|
||||
|
||||
if self.alpha <= 0:
|
||||
raise ValueError(f"alpha must be positive, got {self.alpha}")
|
||||
|
||||
if self.beta <= 0:
|
||||
raise ValueError(f"beta must be positive, got {self.beta}")
|
||||
|
||||
|
||||
@ObjectiveConfig.register_subclass("flow_matching")
|
||||
@dataclass
|
||||
class FlowMatchingConfig(ObjectiveConfig):
|
||||
"""Configuration for flow matching training and inference.
|
||||
|
||||
These parameters control the velocity field learning and ODE integration
|
||||
process for flow matching models.
|
||||
"""
|
||||
|
||||
objective_name: str = field(default="flow_matching", init=False)
|
||||
|
||||
# Flow path construction
|
||||
sigma_min: float = 0.0 # Minimum noise level in flow interpolation path
|
||||
|
||||
# ODE integration for inference
|
||||
num_integration_steps: int = (
|
||||
100 # Number of ODE integration steps (increased from 50 for smoother trajectories)
|
||||
)
|
||||
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
|
||||
|
||||
# Timestep sampling strategy for training
|
||||
# Beta distribution found to be the most effective in practice, so it is the default
|
||||
timestep_sampling: TimestepSamplingConfig = field(default_factory=BetaTimestepSamplingConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
if not (0.0 <= self.sigma_min <= 1.0):
|
||||
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
|
||||
|
||||
if self.num_integration_steps <= 0:
|
||||
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
|
||||
|
||||
if self.integration_method not in ["euler", "rk4"]:
|
||||
raise ValueError(f"integration_method must be 'euler' or 'rk4', got {self.integration_method}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerConfig:
|
||||
"""Configuration for Transformer-based prediction model.
|
||||
|
||||
These parameters control the transformer architecture used for noise/velocity
|
||||
prediction in diffusion and flow matching models.
|
||||
"""
|
||||
|
||||
# Transformer architecture parameters
|
||||
hidden_dim: int = 512 # Hidden dimension of transformer
|
||||
num_layers: int = 6 # Number of transformer layers
|
||||
num_heads: int = 8 # Number of attention heads
|
||||
dropout: float = 0.1 # Dropout rate
|
||||
use_positional_encoding: bool = False # Whether to use absolute positional encoding
|
||||
diffusion_step_embed_dim: int = 256 # Timestep embedding size
|
||||
|
||||
# RoPE (Rotary Position Embedding) configuration
|
||||
use_rope: bool = True # Whether to use Rotary Position Embedding in attention (baseline is True)
|
||||
rope_base: float = 10000.0 # Base frequency for RoPE computation
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate Transformer-specific parameters."""
|
||||
if self.hidden_dim <= 0:
|
||||
raise ValueError("hidden_dim must be positive")
|
||||
|
||||
if self.num_layers <= 0:
|
||||
raise ValueError("num_layers must be positive")
|
||||
|
||||
if self.num_heads <= 0:
|
||||
raise ValueError("num_heads must be positive")
|
||||
|
||||
if self.hidden_dim % self.num_heads != 0:
|
||||
raise ValueError("hidden_dim must be divisible by num_heads")
|
||||
|
||||
if not (0.0 <= self.dropout <= 1.0):
|
||||
raise ValueError("dropout must be between 0.0 and 1.0")
|
||||
|
||||
if self.diffusion_step_embed_dim <= 0:
|
||||
raise ValueError("diffusion_step_embed_dim must be positive")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionEncoderConfig:
|
||||
"""Configuration for CLIP vision encoder.
|
||||
|
||||
Uses CLIPVisionModel from transformers library.
|
||||
CLS token usage is handled automatically.
|
||||
CLIP's internal preprocessing (resize to 224x224) can be overridden
|
||||
by setting resize_shape and crop_shape.
|
||||
|
||||
All image preprocessing is centralized here:
|
||||
1. Resize (optional) - resize images to target resolution
|
||||
2. Crop (optional) - crop after resize, must be smaller than resize_shape
|
||||
3. Random crop - whether to use random cropping during training
|
||||
|
||||
Any CLIP model from transformers can be used. Examples:
|
||||
- openai/clip-vit-base-patch16 (default, 768 dims)
|
||||
- openai/clip-vit-large-patch14 (1024 dims)
|
||||
- laion/CLIP-ViT-B-32-xlaai256 (alternative CLIP model)
|
||||
"""
|
||||
|
||||
model_name: str = "openai/clip-vit-base-patch16"
|
||||
use_separate_encoder_per_camera: bool = False
|
||||
|
||||
# Learning rate multiplier for vision encoder parameters
|
||||
# Vision encoder learning rate = optimizer_lr * lr_multiplier
|
||||
lr_multiplier: float = 0.1
|
||||
|
||||
# Image preprocessing (centralized)
|
||||
resize_shape: tuple[int, int] | None = None
|
||||
crop_shape: tuple[int, int] | None = (224, 224) # default input size for CLIP
|
||||
crop_is_random: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate that model name contains "clip" to ensure correct encoder type
|
||||
if "clip" not in self.model_name.lower():
|
||||
raise ValueError(
|
||||
f"model_name must be a CLIP model from transformers (contain 'clip'), got '{self.model_name}'"
|
||||
)
|
||||
|
||||
if (
|
||||
self.resize_shape
|
||||
and self.crop_shape
|
||||
and (self.crop_shape[0] > self.resize_shape[0] or self.crop_shape[1] > self.resize_shape[1])
|
||||
):
|
||||
raise ValueError(
|
||||
f"crop_shape {self.crop_shape} must be smaller than or equal to "
|
||||
f"resize_shape {self.resize_shape}. Got crop={self.crop_shape}, resize={self.resize_shape}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextEncoderConfig:
|
||||
"""Configuration for CLIP text encoder.
|
||||
|
||||
Uses CLIP's text encoder to embed task descriptions, which are then
|
||||
used to condition the policy. The text embeddings are processed by
|
||||
a learnable projection layer before being concatenated into the
|
||||
conditioning vector.
|
||||
|
||||
Any HuggingFace CLIP model can be used. Examples:
|
||||
- openai/clip-vit-base-patch16 (default)
|
||||
- openai/clip-vit-large-patch14
|
||||
"""
|
||||
|
||||
model: str = "openai/clip-vit-base-patch16"
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate that model name contains "clip" to ensure correct encoder type
|
||||
if "clip" not in self.model.lower():
|
||||
raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationEncoderConfig:
|
||||
"""Top-level configuration for observation encoding.
|
||||
|
||||
This config combines:
|
||||
- Vision encoding (required): CLIP vision encoder from transformers
|
||||
"""
|
||||
|
||||
vision: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
|
||||
text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("multi_task_dit")
|
||||
@dataclass
|
||||
class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
"""
|
||||
Configuration class for the Multi-Task Diffusion Transformer (DiT) policy.
|
||||
"""Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
|
||||
|
||||
A transformer-based policy that supports both diffusion and flow matching objectives
|
||||
for multi-task robot learning with text and vision conditioning.
|
||||
"""
|
||||
|
||||
# Temporal structure - controls how the policy processes time and predicts actions
|
||||
n_obs_steps: int = 2 # num observations for temporal context (..., t-1, t)
|
||||
horizon: int = 100 # predicted action steps into the future
|
||||
n_action_steps: int = 24 # actions per policy call (receding horizon) -- ~0.8s is a good place to start
|
||||
n_obs_steps: int = 2 # Number of observation steps for temporal context
|
||||
horizon: int = 32 # Number of action steps to predict
|
||||
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
|
||||
|
||||
# Normalization strategy - critical for diffusion model performance
|
||||
# Objective Selection
|
||||
objective: Literal["diffusion", "flow_matching"] = "diffusion"
|
||||
|
||||
# --- Diffusion-specific (used when objective="diffusion") ---
|
||||
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
|
||||
num_train_timesteps: int = 100 # Number of diffusion timesteps
|
||||
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
|
||||
beta_start: float = 0.0001 # Starting noise level
|
||||
beta_end: float = 0.02 # Ending noise level
|
||||
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
|
||||
clip_sample: bool = True # Clip samples during denoising
|
||||
clip_sample_range: float = 1.0 # Clipping range [-x, x]
|
||||
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
|
||||
|
||||
# --- Flow Matching-specific (used when objective="flow_matching") ---
|
||||
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
|
||||
num_integration_steps: int = 100 # ODE integration steps at inference
|
||||
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
|
||||
timestep_sampling_strategy: Literal["uniform", "beta"] = "beta"
|
||||
|
||||
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
|
||||
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
|
||||
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
|
||||
|
||||
# Transformer Architecture
|
||||
hidden_dim: int = 512 # Transformer hidden dimension
|
||||
num_layers: int = 6 # Number of transformer layers
|
||||
num_heads: int = 8 # Number of attention heads
|
||||
dropout: float = 0.1 # Dropout rate
|
||||
use_positional_encoding: bool = False # Use absolute positional encoding
|
||||
timestep_embed_dim: int = 256 # Timestep embedding dimension
|
||||
use_rope: bool = True # Use Rotary Position Embedding
|
||||
rope_base: float = 10000.0 # RoPE base frequency
|
||||
|
||||
# Vision Encoder (CLIP)
|
||||
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
||||
use_separate_encoder_per_camera: bool = False # Separate encoder per camera view
|
||||
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
|
||||
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
|
||||
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
|
||||
image_crop_is_random: bool = True # Random crop during training, center at inference
|
||||
|
||||
# Text Encoder (CLIP)
|
||||
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD, # Standard ImageNet normalization for vision
|
||||
"STATE": NormalizationMode.MIN_MAX, # [-1,1] range for proper diffusion clipping
|
||||
"ACTION": NormalizationMode.MIN_MAX, # [-1,1] range required for diffusion process
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
drop_n_last_frames: int | None = None # Auto-calculated: horizon - n_action_steps - n_obs_steps + 1
|
||||
observation_encoder: ObservationEncoderConfig = field(default_factory=ObservationEncoderConfig)
|
||||
transformer: TransformerConfig = field(default_factory=TransformerConfig)
|
||||
objective: ObjectiveConfig = field(default_factory=DiffusionConfig)
|
||||
do_mask_loss_for_padding: bool = False # same logic as is implemented in LeRobot DP implementation
|
||||
|
||||
# training optimizer and scheduler hyperparameters
|
||||
# Training/Optimizer
|
||||
optimizer_lr: float = 2e-5
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.0 # No weight decay is suggested to be optimal
|
||||
optimizer_weight_decay: float = 0.0
|
||||
scheduler_name: str = "cosine"
|
||||
scheduler_warmup_steps: int = 0 # No warmup found to be optimal
|
||||
scheduler_warmup_steps: int = 0
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
# Auto-calculated
|
||||
drop_n_last_frames: int | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
@@ -320,11 +108,78 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
if self.drop_n_last_frames is None:
|
||||
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
"""Return Adam optimizer configuration optimized for diffusion training.
|
||||
self._validate()
|
||||
|
||||
Note: Vision encoder learning rate is set separately via get_optim_params.
|
||||
"""
|
||||
def _validate(self):
|
||||
"""Validate configuration parameters."""
|
||||
# Transformer validation
|
||||
if self.hidden_dim <= 0:
|
||||
raise ValueError("hidden_dim must be positive")
|
||||
if self.num_layers <= 0:
|
||||
raise ValueError("num_layers must be positive")
|
||||
if self.num_heads <= 0:
|
||||
raise ValueError("num_heads must be positive")
|
||||
if self.hidden_dim % self.num_heads != 0:
|
||||
raise ValueError("hidden_dim must be divisible by num_heads")
|
||||
if not (0.0 <= self.dropout <= 1.0):
|
||||
raise ValueError("dropout must be between 0.0 and 1.0")
|
||||
|
||||
# Vision encoder validation
|
||||
if "clip" not in self.vision_encoder_name.lower():
|
||||
raise ValueError(
|
||||
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
|
||||
)
|
||||
if (
|
||||
self.image_resize_shape
|
||||
and self.image_crop_shape
|
||||
and (
|
||||
self.image_crop_shape[0] > self.image_resize_shape[0]
|
||||
or self.image_crop_shape[1] > self.image_resize_shape[1]
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"image_crop_shape {self.image_crop_shape} must be <= image_resize_shape {self.image_resize_shape}"
|
||||
)
|
||||
|
||||
# Text encoder validation
|
||||
if "clip" not in self.text_encoder_name.lower():
|
||||
raise ValueError(
|
||||
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
|
||||
)
|
||||
|
||||
# Objective-specific validation
|
||||
if self.objective == "diffusion":
|
||||
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
|
||||
raise ValueError(
|
||||
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
|
||||
)
|
||||
if self.prediction_type not in ["epsilon", "sample"]:
|
||||
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
|
||||
if self.num_train_timesteps <= 0:
|
||||
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
|
||||
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
|
||||
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
|
||||
|
||||
elif self.objective == "flow_matching":
|
||||
if not (0.0 <= self.sigma_min <= 1.0):
|
||||
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
|
||||
if self.num_integration_steps <= 0:
|
||||
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
|
||||
if self.integration_method not in ["euler", "rk4"]:
|
||||
raise ValueError(
|
||||
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
|
||||
)
|
||||
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
|
||||
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
|
||||
if self.timestep_sampling_strategy == "beta":
|
||||
if not (0.0 < self.timestep_sampling_s <= 1.0):
|
||||
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
|
||||
if self.timestep_sampling_alpha <= 0:
|
||||
raise ValueError("timestep_sampling_alpha must be positive")
|
||||
if self.timestep_sampling_beta <= 0:
|
||||
raise ValueError("timestep_sampling_beta must be positive")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
@@ -333,7 +188,6 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
"""Return learning rate scheduler configuration."""
|
||||
return DiffuserSchedulerConfig(
|
||||
name=self.scheduler_name,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
@@ -341,56 +195,41 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate that required input features are present and properly configured."""
|
||||
# Robot state is always present via self.robot_state_feature, so we don't need to enforce images/env_state
|
||||
# This allows for testing and simple state-only policies
|
||||
|
||||
# Validate crop shape fits within image dimensions
|
||||
crop_shape = self.observation_encoder.vision.crop_shape
|
||||
if crop_shape is not None:
|
||||
if self.image_crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if crop_shape[0] > image_ft.shape[1] or crop_shape[1] > image_ft.shape[2]:
|
||||
if (
|
||||
self.image_crop_shape[0] > image_ft.shape[1]
|
||||
or self.image_crop_shape[1] > image_ft.shape[2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
f"image_crop_shape {self.image_crop_shape} doesn't fit within image shape {image_ft.shape} "
|
||||
f"for '{key}'"
|
||||
)
|
||||
|
||||
# Ensure all images have same shape (current limitation)
|
||||
if len(self.image_features) > 0:
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
first_key, first_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
if image_ft.shape != first_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
|
||||
)
|
||||
|
||||
@property
|
||||
def model_objective(self) -> str:
|
||||
return self.objective.objective_name
|
||||
|
||||
@property
|
||||
def is_diffusion(self) -> bool:
|
||||
return isinstance(self.objective, DiffusionConfig)
|
||||
return self.objective == "diffusion"
|
||||
|
||||
@property
|
||||
def is_flow_matching(self) -> bool:
|
||||
return isinstance(self.objective, FlowMatchingConfig)
|
||||
|
||||
def get_objective_config(self) -> DiffusionConfig | FlowMatchingConfig:
|
||||
"""Get the objective-specific configuration with proper typing."""
|
||||
return self.objective
|
||||
return self.objective == "flow_matching"
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
"""Delta indices for stacking observations. Provides temporal context."""
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
"""Delta indices for action horizon prediction."""
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
"""Indices for reward deltas (not used in diffusion policy)."""
|
||||
return None
|
||||
|
||||
@@ -52,23 +52,22 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
action_dim = config.action_feature.shape[0]
|
||||
horizon = config.horizon
|
||||
|
||||
self.model_objective = config.model_objective
|
||||
if config.is_diffusion:
|
||||
self.objective = DiffusionObjective(
|
||||
config.get_objective_config(),
|
||||
config,
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
||||
)
|
||||
elif config.is_flow_matching:
|
||||
self.objective = FlowMatchingObjective(
|
||||
config.get_objective_config(),
|
||||
config,
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_objective: {self.model_objective}")
|
||||
raise ValueError(f"Unsupported objective: {config.objective}")
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -90,7 +89,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
{"params": non_vision_params},
|
||||
{
|
||||
"params": vision_encoder_params,
|
||||
"lr": self.config.optimizer_lr * self.config.observation_encoder.vision.lr_multiplier,
|
||||
"lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -118,8 +117,8 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
if self.config.observation_encoder.text:
|
||||
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
||||
# Always include task queue for text conditioning
|
||||
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
|
||||
@@ -14,14 +14,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This module contains a base objective class, and implementation of the objective
|
||||
classes for use in the Multi-Task Diffusion Transformer Policy.
|
||||
"""Objective implementations for Multi-Task DiT policy.
|
||||
|
||||
Architecture:
|
||||
- BaseObjective: Abstract interface definition
|
||||
- DiffusionObjective: Implements standard DDPM/DDIM diffusion objective
|
||||
- FlowMatchingObjective: Implements flow matching objective
|
||||
- DiffusionObjective: Standard DDPM/DDIM diffusion
|
||||
- FlowMatchingObjective: Flow matching with ODE integration
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -35,10 +31,7 @@ from torch import Tensor
|
||||
|
||||
|
||||
class BaseObjective(ABC):
|
||||
"""
|
||||
Base class for objectives used in Multi-Task DiT policy.
|
||||
Defines the interface for training loss computation and conditional sampling.
|
||||
"""
|
||||
"""Base class for objectives used in Multi-Task DiT policy."""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int):
|
||||
self.config = config
|
||||
@@ -47,38 +40,17 @@ class BaseObjective(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
"""Compute training loss for the objective.
|
||||
|
||||
Args:
|
||||
model: The prediction network
|
||||
batch: Training batch with observations and actions
|
||||
conditioning_vec: Encoded observation features for conditioning
|
||||
|
||||
Returns:
|
||||
Scalar loss tensor
|
||||
"""
|
||||
"""Compute training loss."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||
"""Generate action samples conditioned on embedded observation features.
|
||||
|
||||
Args:
|
||||
model: The prediction network
|
||||
batch_size: The number of samples to generate
|
||||
conditioning_vec: Encoded observation features for conditioning
|
||||
|
||||
Returns:
|
||||
Generated action sequences (batch_size, horizon, action_dim)
|
||||
"""
|
||||
"""Generate action samples conditioned on observations."""
|
||||
pass
|
||||
|
||||
|
||||
class DiffusionObjective(BaseObjective):
|
||||
"""Standard diffusion (DDPM/DDIM) objective implementation.
|
||||
|
||||
Contains the noise scheduler, training loss, and conditional sampling.
|
||||
"""
|
||||
"""Standard diffusion (DDPM/DDIM) objective implementation."""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||
super().__init__(config, action_dim, horizon)
|
||||
@@ -90,8 +62,8 @@ class DiffusionObjective(BaseObjective):
|
||||
"beta_start": config.beta_start,
|
||||
"beta_end": config.beta_end,
|
||||
"beta_schedule": config.beta_schedule,
|
||||
"clip_sample": getattr(config, "clip_sample", True),
|
||||
"clip_sample_range": getattr(config, "clip_sample_range", 1.0),
|
||||
"clip_sample": config.clip_sample,
|
||||
"clip_sample_range": config.clip_sample_range,
|
||||
"prediction_type": config.prediction_type,
|
||||
}
|
||||
|
||||
@@ -105,7 +77,7 @@ class DiffusionObjective(BaseObjective):
|
||||
# Inference steps default to training steps if not provided
|
||||
self.num_inference_steps = (
|
||||
config.num_inference_steps
|
||||
if getattr(config, "num_inference_steps", None) is not None
|
||||
if config.num_inference_steps is not None
|
||||
else self.noise_scheduler.config.num_train_timesteps
|
||||
)
|
||||
|
||||
@@ -161,40 +133,29 @@ class DiffusionObjective(BaseObjective):
|
||||
|
||||
|
||||
class FlowMatchingObjective(BaseObjective):
|
||||
"""
|
||||
Flow matching objective: trains a model to predict velocity fields v_θ(x, t) that transports
|
||||
noise to data. This basically interpolates as path between noise and a trained distribution
|
||||
with a set target velocity.
|
||||
"""
|
||||
"""Flow matching objective: trains a model to predict velocity fields."""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||
super().__init__(config, action_dim, horizon)
|
||||
self.do_mask_loss_for_padding = do_mask_loss_for_padding
|
||||
|
||||
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
|
||||
"""Sample timesteps according to configured strategy.
|
||||
|
||||
Uniform: Sample t uniformly from [0,1]
|
||||
Beta: Sample t from Beta(α,β) scaled to [0,s], emphasizing high noise (low t)
|
||||
"""
|
||||
if self.config.timestep_sampling.strategy_name == "uniform":
|
||||
"""Sample timesteps according to configured strategy."""
|
||||
if self.config.timestep_sampling_strategy == "uniform":
|
||||
return torch.rand(batch_size, device=device)
|
||||
elif self.config.timestep_sampling.strategy_name == "beta":
|
||||
elif self.config.timestep_sampling_strategy == "beta":
|
||||
# Sample u ~ Beta(α, β) then transform: t = s(1-u)
|
||||
# This emphasizes t near 0 (high noise) when α > β
|
||||
beta_dist = torch.distributions.Beta(
|
||||
self.config.timestep_sampling.alpha, self.config.timestep_sampling.beta
|
||||
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
|
||||
)
|
||||
u = beta_dist.sample((batch_size,)).to(device)
|
||||
return self.config.timestep_sampling.s * (1.0 - u)
|
||||
return self.config.timestep_sampling_s * (1.0 - u)
|
||||
else:
|
||||
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling.strategy_name}")
|
||||
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
|
||||
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
"""Compute flow matching training loss.
|
||||
|
||||
Trains the model to predict the velocity field along linear interpolation paths.
|
||||
"""
|
||||
"""Compute flow matching training loss."""
|
||||
data = batch["action"] # Clean action sequences (B, T, D)
|
||||
batch_size = data.shape[0]
|
||||
device = data.device
|
||||
@@ -217,10 +178,7 @@ class FlowMatchingObjective(BaseObjective):
|
||||
return loss.mean()
|
||||
|
||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||
"""Generate actions by integrating the learned velocity field via ODE.
|
||||
|
||||
Solves: dx/dt = v_θ(x,t) from t=0 (noise) to t=1 (data)
|
||||
"""
|
||||
"""Generate actions by integrating the learned velocity field via ODE."""
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
@@ -244,9 +202,7 @@ class FlowMatchingObjective(BaseObjective):
|
||||
def _euler_integrate(
|
||||
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
||||
) -> Tensor:
|
||||
"""
|
||||
Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)
|
||||
"""
|
||||
"""Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)"""
|
||||
x = x_init
|
||||
|
||||
for i in range(len(time_grid) - 1):
|
||||
@@ -268,23 +224,10 @@ class FlowMatchingObjective(BaseObjective):
|
||||
def _rk4_integrate(
|
||||
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
||||
) -> Tensor:
|
||||
"""4th-order Runge-Kutta integration.
|
||||
|
||||
Uses 4 velocity evaluations per step:
|
||||
k1 = v(x, t)
|
||||
k2 = v(x + dt·k1/2, t + dt/2)
|
||||
k3 = v(x + dt·k2/2, t + dt/2)
|
||||
k4 = v(x + dt·k3, t + dt)
|
||||
x_next = x + dt/6·(k1 + 2k2 + 2k3 + k4)
|
||||
|
||||
4x slower than Euler but more accurate.
|
||||
|
||||
Note: In practice, this seems to not matter much
|
||||
"""
|
||||
"""4th-order Runge-Kutta integration."""
|
||||
x = x_init
|
||||
|
||||
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
|
||||
"""dynamics helper to get velocity at (x, t)"""
|
||||
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
|
||||
with torch.no_grad():
|
||||
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
|
||||
|
||||
@@ -68,9 +68,7 @@ class CLIPVisionEncoder(nn.Module):
|
||||
|
||||
|
||||
class CLIPTextEncoder(nn.Module):
|
||||
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
|
||||
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
|
||||
"""
|
||||
"""CLIP text encoder with frozen weights and a learnable projection layer."""
|
||||
|
||||
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
||||
super().__init__()
|
||||
@@ -126,21 +124,20 @@ class ObservationEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
vision_config = config.observation_encoder.vision
|
||||
|
||||
self._setup_preprocessing(vision_config)
|
||||
self._setup_preprocessing(config)
|
||||
|
||||
if config.image_features:
|
||||
self.num_cameras = len(config.image_features)
|
||||
self.camera_names = list(config.image_features.keys()) # Preserve ordering
|
||||
self.camera_names = list(config.image_features.keys())
|
||||
|
||||
if vision_config.use_separate_encoder_per_camera:
|
||||
if config.use_separate_encoder_per_camera:
|
||||
self.vision_encoders = nn.ModuleList(
|
||||
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
|
||||
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
|
||||
)
|
||||
self.vision_encoder = None
|
||||
else:
|
||||
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
|
||||
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
|
||||
self.vision_encoders = None
|
||||
else:
|
||||
self.vision_encoder = None
|
||||
@@ -158,9 +155,8 @@ class ObservationEncoder(nn.Module):
|
||||
else:
|
||||
self.env_state_dim = 0
|
||||
|
||||
text_config = config.observation_encoder.text
|
||||
self.text_dim = config.transformer.hidden_dim
|
||||
self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
|
||||
self.text_dim = config.hidden_dim
|
||||
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
|
||||
|
||||
self._setup_vector_output()
|
||||
|
||||
@@ -173,22 +169,23 @@ class ObservationEncoder(nn.Module):
|
||||
|
||||
return images
|
||||
|
||||
def _setup_preprocessing(self, vision_config):
|
||||
def _setup_preprocessing(self, config):
|
||||
"""Setup image preprocessing transforms."""
|
||||
if vision_config.resize_shape is not None:
|
||||
if config.image_resize_shape is not None:
|
||||
self.do_resize = True
|
||||
self.resize = torchvision.transforms.Resize(
|
||||
size=vision_config.resize_shape,
|
||||
size=config.image_resize_shape,
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
|
||||
antialias=True,
|
||||
)
|
||||
else:
|
||||
self.do_resize = False
|
||||
if vision_config.crop_shape is not None:
|
||||
|
||||
if config.image_crop_shape is not None:
|
||||
self.do_crop = True
|
||||
self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
|
||||
if vision_config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
|
||||
if config.image_crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -199,7 +196,7 @@ class ObservationEncoder(nn.Module):
|
||||
|
||||
# Vision features - get CLS token feature dimension
|
||||
if self.vision_encoder is not None or self.vision_encoders is not None:
|
||||
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders.values()))
|
||||
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
|
||||
|
||||
# Get output shape from encoder (deterministic for CLS tokens)
|
||||
feature_map_shape = encoder_to_check.get_output_shape()
|
||||
@@ -233,8 +230,7 @@ class ObservationEncoder(nn.Module):
|
||||
# Shape is (B, N, C, H, W) - add time dimension
|
||||
images = images.unsqueeze(1) # (B, 1, N, C, H, W)
|
||||
|
||||
vision_config = self.config.observation_encoder.vision
|
||||
if vision_config.use_separate_encoder_per_camera:
|
||||
if self.config.use_separate_encoder_per_camera:
|
||||
# Process each camera with its own encoder
|
||||
camera_features = []
|
||||
|
||||
|
||||
@@ -308,32 +308,29 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class DiffusionTransformer(nn.Module):
|
||||
"""
|
||||
Transformer-based diffusion noise prediction model.
|
||||
"""
|
||||
"""Transformer-based diffusion noise prediction model."""
|
||||
|
||||
def __init__(self, config, conditioning_dim: int):
|
||||
"""Initialize transformer for noise prediction.
|
||||
|
||||
Args:
|
||||
config: Multi-Task DiTConfig with transformer parameters
|
||||
config: MultiTaskDiTConfig with transformer parameters
|
||||
conditioning_dim: Dimension of concatenated observation features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.transformer_config = config.transformer
|
||||
self.conditioning_dim = conditioning_dim
|
||||
|
||||
self.action_dim = config.action_feature.shape[0]
|
||||
self.horizon = config.horizon
|
||||
self.hidden_size = self.transformer_config.hidden_dim
|
||||
self.num_layers = self.transformer_config.num_layers
|
||||
self.num_heads = self.transformer_config.num_heads
|
||||
self.dropout = self.transformer_config.dropout
|
||||
self.use_rope = self.transformer_config.use_rope
|
||||
self.hidden_size = config.hidden_dim
|
||||
self.num_layers = config.num_layers
|
||||
self.num_heads = config.num_heads
|
||||
self.dropout = config.dropout
|
||||
self.use_rope = config.use_rope
|
||||
|
||||
self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
|
||||
self.timestep_embed_dim = config.timestep_embed_dim
|
||||
self.time_mlp = nn.Sequential(
|
||||
SinusoidalPosEmb(self.timestep_embed_dim),
|
||||
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
|
||||
@@ -347,7 +344,7 @@ class DiffusionTransformer(nn.Module):
|
||||
# Project action dimensions to hidden size
|
||||
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
|
||||
|
||||
if self.transformer_config.use_positional_encoding:
|
||||
if config.use_positional_encoding:
|
||||
# Learnable positional embeddings for sequence positions (absolute encoding)
|
||||
self.pos_embedding = nn.Parameter(
|
||||
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
|
||||
@@ -363,8 +360,8 @@ class DiffusionTransformer(nn.Module):
|
||||
num_features=self.cond_dim,
|
||||
dropout=self.dropout,
|
||||
use_rope=self.use_rope,
|
||||
max_seq_len=self.horizon, # This remains fixed because we aren't generating variable length sequences
|
||||
rope_base=getattr(self.transformer_config, "rope_base", 10000.0),
|
||||
max_seq_len=self.horizon,
|
||||
rope_base=config.rope_base,
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
@@ -377,9 +374,7 @@ class DiffusionTransformer(nn.Module):
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""
|
||||
Zero-initializing the final linear layer of adaLN_modulation in each block improves training stability
|
||||
"""
|
||||
"""Zero-initialize final linear layer of adaLN_modulation for training stability."""
|
||||
for block in self.transformer_blocks:
|
||||
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
@@ -28,11 +28,7 @@ import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
|
||||
DiffusionConfig,
|
||||
FlowMatchingConfig,
|
||||
MultiTaskDiTConfig,
|
||||
)
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
@@ -108,13 +104,12 @@ def create_config(
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
# Use smaller model for faster tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
|
||||
# Use smaller model for faster tests
|
||||
config.transformer.hidden_dim = 128
|
||||
config.transformer.num_layers = 2
|
||||
config.transformer.num_heads = 4
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
@@ -189,18 +184,28 @@ def test_multi_task_dit_policy_diffusion_objective():
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
config.objective = DiffusionConfig(
|
||||
# Use diffusion objective
|
||||
objective="diffusion",
|
||||
noise_scheduler_type="DDPM",
|
||||
num_train_timesteps=100,
|
||||
num_inference_steps=10,
|
||||
# Smaller model for tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
@@ -235,18 +240,28 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
config.objective = FlowMatchingConfig(
|
||||
# Use flow matching objective
|
||||
objective="flow_matching",
|
||||
sigma_min=0.0,
|
||||
num_integration_steps=10, # Use fewer steps for faster tests
|
||||
num_integration_steps=10, # Fewer steps for faster tests
|
||||
integration_method="euler",
|
||||
# Smaller model for tests
|
||||
hidden_dim=128,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
@@ -373,5 +388,5 @@ def test_multi_task_dit_policy_get_optim_params():
|
||||
# Second group is vision encoder params with different lr
|
||||
assert "params" in param_groups[1]
|
||||
assert "lr" in param_groups[1]
|
||||
expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier
|
||||
expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
|
||||
assert param_groups[1]["lr"] == expected_lr
|
||||
|
||||
Reference in New Issue
Block a user