diff --git a/src/lerobot/policies/multi_task_dit/__init__.py b/src/lerobot/policies/multi_task_dit/__init__.py
new file mode 100644
index 000000000..4265cc3e6
--- /dev/null
+++ b/src/lerobot/policies/multi_task_dit/__init__.py
@@ -0,0 +1 @@
+#!/usr/bin/env python
diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
index 2089b1372..dc7708551 100644
--- a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
+++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
@@ -15,8 +15,7 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-
-import draccus
+from typing import Literal

 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode
@@ -24,295 +23,84 @@ from lerobot.optim.optimizers import AdamConfig
 from lerobot.optim.schedulers import DiffuserSchedulerConfig


-@dataclass
-class ObjectiveConfig(draccus.ChoiceRegistry):
-    """Base configuration for model objectives (diffusion, flow matching, etc.)."""
-
-    pass
-
-
-@ObjectiveConfig.register_subclass("diffusion")
-@dataclass
-class DiffusionConfig(ObjectiveConfig):
-    """Configuration for standard diffusion model training and inference.
-
-    These parameters control the noise scheduling and denoising process for
-    standard DDPM/DDIM diffusion models.
-    """
-
-    objective_name: str = field(default="diffusion", init=False)
-
-    # Noise scheduler configuration - controls diffusion process
-    noise_scheduler_type: str = "DDPM"  # "DDPM" or "DDIM"
-    num_train_timesteps: int = 100  # 100 noise levels for fine-grained control
-    beta_schedule: str = "squaredcos_cap_v2"  # Cosine schedule prevents extreme noise
-    beta_start: float = 0.0001  # Small initial noise level
-    beta_end: float = 0.02  # Moderate final noise level
-    prediction_type: str = "epsilon"  # Predict noise (works better than direct prediction)
-    clip_sample: bool = True  # Prevent extreme action values
-    clip_sample_range: float = 1.0  # Clip to [-1, 1] range
-
-    # Inference configuration
-    num_inference_steps: int | None = None  # Default to num_train_timesteps
-
-    def __post_init__(self):
-        """Validate diffusion-specific parameters."""
-        if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
-            raise ValueError(
-                f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
-            )
-
-        if self.prediction_type not in ["epsilon", "sample"]:
-            raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
-
-        if self.num_train_timesteps <= 0:
-            raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
-
-        if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
-            raise ValueError(
-                "beta values must satisfy 0 <= beta_start <= beta_end <= 1, "
-                f"got {self.beta_start}, {self.beta_end}"
-            )
-
-
-@dataclass
-class TimestepSamplingConfig(draccus.ChoiceRegistry):
-    """Base configuration for timestep sampling strategies during training."""
-
-    pass
-
-
-@TimestepSamplingConfig.register_subclass("uniform")
-@dataclass
-class UniformTimestepSamplingConfig(TimestepSamplingConfig):
-    """Uniform timestep sampling from [0, 1]."""
-
-    strategy_name: str = field(default="uniform", init=False)
-
-
-@TimestepSamplingConfig.register_subclass("beta")
-@dataclass
-class BetaTimestepSamplingConfig(TimestepSamplingConfig):
-    """Beta distribution timestep sampling.
-
-    Samples from Beta distribution emphasizing low timesteps (high noise).
-
-    This was inspired on the work from Physical Intelligence PI-0 model,
-    where they suggested the beta distribution for sampling timesteps
-    during training improved sample quality.
-    """
-
-    strategy_name: str = field(default="beta", init=False)
-
-    s: float = 0.999  # Max timestep threshold for beta sampling
-    alpha: float = 1.5  # Beta distribution alpha parameter
-    beta: float = 1.0  # Beta distribution beta parameter
-
-    def __post_init__(self):
-        if not (0.0 < self.s <= 1.0):
-            raise ValueError(f"s must be in (0, 1], got {self.s}")
-
-        if self.alpha <= 0:
-            raise ValueError(f"alpha must be positive, got {self.alpha}")
-
-        if self.beta <= 0:
-            raise ValueError(f"beta must be positive, got {self.beta}")
-
-
-@ObjectiveConfig.register_subclass("flow_matching")
-@dataclass
-class FlowMatchingConfig(ObjectiveConfig):
-    """Configuration for flow matching training and inference.
-
-    These parameters control the velocity field learning and ODE integration
-    process for flow matching models.
-    """
-
-    objective_name: str = field(default="flow_matching", init=False)
-
-    # Flow path construction
-    sigma_min: float = 0.0  # Minimum noise level in flow interpolation path
-
-    # ODE integration for inference
-    num_integration_steps: int = (
-        100  # Number of ODE integration steps (increased from 50 for smoother trajectories)
-    )
-    integration_method: str = "euler"  # ODE solver: "euler" or "rk4"
-
-    # Timestep sampling strategy for training
-    # Beta distribution found to be the most effective in practice, so it is the default
-    timestep_sampling: TimestepSamplingConfig = field(default_factory=BetaTimestepSamplingConfig)
-
-    def __post_init__(self):
-        if not (0.0 <= self.sigma_min <= 1.0):
-            raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
-
-        if self.num_integration_steps <= 0:
-            raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
-
-        if self.integration_method not in ["euler", "rk4"]:
-            raise ValueError(f"integration_method must be 'euler' or 'rk4', got {self.integration_method}")
-
-
-@dataclass
-class TransformerConfig:
-    """Configuration for Transformer-based prediction model.
-
-    These parameters control the transformer architecture used for noise/velocity
-    prediction in diffusion and flow matching models.
- """ - - # Transformer architecture parameters - hidden_dim: int = 512 # Hidden dimension of transformer - num_layers: int = 6 # Number of transformer layers - num_heads: int = 8 # Number of attention heads - dropout: float = 0.1 # Dropout rate - use_positional_encoding: bool = False # Whether to use absolute positional encoding - diffusion_step_embed_dim: int = 256 # Timestep embedding size - - # RoPE (Rotary Position Embedding) configuration - use_rope: bool = True # Whether to use Rotary Position Embedding in attention (baseline is True) - rope_base: float = 10000.0 # Base frequency for RoPE computation - - def __post_init__(self): - """Validate Transformer-specific parameters.""" - if self.hidden_dim <= 0: - raise ValueError("hidden_dim must be positive") - - if self.num_layers <= 0: - raise ValueError("num_layers must be positive") - - if self.num_heads <= 0: - raise ValueError("num_heads must be positive") - - if self.hidden_dim % self.num_heads != 0: - raise ValueError("hidden_dim must be divisible by num_heads") - - if not (0.0 <= self.dropout <= 1.0): - raise ValueError("dropout must be between 0.0 and 1.0") - - if self.diffusion_step_embed_dim <= 0: - raise ValueError("diffusion_step_embed_dim must be positive") - - -@dataclass -class VisionEncoderConfig: - """Configuration for CLIP vision encoder. - - Uses CLIPVisionModel from transformers library. - CLS token usage is handled automatically. - CLIP's internal preprocessing (resize to 224x224) can be overridden - by setting resize_shape and crop_shape. - - All image preprocessing is centralized here: - 1. Resize (optional) - resize images to target resolution - 2. Crop (optional) - crop after resize, must be smaller than resize_shape - 3. Random crop - whether to use random cropping during training - - Any CLIP model from transformers can be used. Examples: - - openai/clip-vit-base-patch16 (default, 768 dims) - - openai/clip-vit-large-patch14 (1024 dims) - - laion/CLIP-ViT-B-32-xlaai256 (alternative CLIP model) - """ - - model_name: str = "openai/clip-vit-base-patch16" - use_separate_encoder_per_camera: bool = False - - # Learning rate multiplier for vision encoder parameters - # Vision encoder learning rate = optimizer_lr * lr_multiplier - lr_multiplier: float = 0.1 - - # Image preprocessing (centralized) - resize_shape: tuple[int, int] | None = None - crop_shape: tuple[int, int] | None = (224, 224) # default input size for CLIP - crop_is_random: bool = True - - def __post_init__(self): - # Validate that model name contains "clip" to ensure correct encoder type - if "clip" not in self.model_name.lower(): - raise ValueError( - f"model_name must be a CLIP model from transformers (contain 'clip'), got '{self.model_name}'" - ) - - if ( - self.resize_shape - and self.crop_shape - and (self.crop_shape[0] > self.resize_shape[0] or self.crop_shape[1] > self.resize_shape[1]) - ): - raise ValueError( - f"crop_shape {self.crop_shape} must be smaller than or equal to " - f"resize_shape {self.resize_shape}. Got crop={self.crop_shape}, resize={self.resize_shape}" - ) - - -@dataclass -class TextEncoderConfig: - """Configuration for CLIP text encoder. - - Uses CLIP's text encoder to embed task descriptions, which are then - used to condition the policy. The text embeddings are processed by - a learnable projection layer before being concatenated into the - conditioning vector. - - Any HuggingFace CLIP model can be used. 
-    Examples:
-    - openai/clip-vit-base-patch16 (default)
-    - openai/clip-vit-large-patch14
-    """
-
-    model: str = "openai/clip-vit-base-patch16"
-
-    def __post_init__(self):
-        # Validate that model name contains "clip" to ensure correct encoder type
-        if "clip" not in self.model.lower():
-            raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'")
-
-
-@dataclass
-class ObservationEncoderConfig:
-    """Top-level configuration for observation encoding.
-
-    This config combines:
-    - Vision encoding (required): CLIP vision encoder from transformers
-    """
-
-    vision: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
-    text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
-
-
 @PreTrainedConfig.register_subclass("multi_task_dit")
 @dataclass
 class MultiTaskDiTConfig(PreTrainedConfig):
-    """
-    Configuration class for the Multi-Task Diffusion Transformer (DiT) policy.
+    """Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
+
+    A transformer-based policy that supports both diffusion and flow matching objectives
+    for multi-task robot learning with text and vision conditioning.
     """

-    # Temporal structure - controls how the policy processes time and predicts actions
-    n_obs_steps: int = 2  # num observations for temporal context (..., t-1, t)
-    horizon: int = 100  # predicted action steps into the future
-    n_action_steps: int = 24  # actions per policy call (receding horizon) -- ~0.8s is a good place to start
+    n_obs_steps: int = 2  # Number of observation steps for temporal context
+    horizon: int = 32  # Number of action steps to predict
+    n_action_steps: int = 24  # Actions executed per policy call (~0.8s at 30Hz)

-    # Normalization strategy - critical for diffusion model performance
+    # Objective Selection
+    objective: Literal["diffusion", "flow_matching"] = "diffusion"
+
+    # --- Diffusion-specific (used when objective="diffusion") ---
+    noise_scheduler_type: str = "DDPM"  # "DDPM" or "DDIM"
+    num_train_timesteps: int = 100  # Number of diffusion timesteps
+    beta_schedule: str = "squaredcos_cap_v2"  # Noise schedule type
+    beta_start: float = 0.0001  # Starting noise level
+    beta_end: float = 0.02  # Ending noise level
+    prediction_type: str = "epsilon"  # "epsilon" (predict noise) or "sample" (predict clean)
+    clip_sample: bool = True  # Clip samples during denoising
+    clip_sample_range: float = 1.0  # Clipping range [-x, x]
+    num_inference_steps: int | None = None  # Denoising steps at inference (defaults to num_train_timesteps)
+
+    # --- Flow Matching-specific (used when objective="flow_matching") ---
+    sigma_min: float = 0.0  # Minimum noise in flow interpolation path
+    num_integration_steps: int = 100  # ODE integration steps at inference
+    integration_method: str = "euler"  # ODE solver: "euler" or "rk4"
+    timestep_sampling_strategy: Literal["uniform", "beta"] = "beta"
+
+    timestep_sampling_s: float = 0.999  # (beta only) Max timestep threshold
+    timestep_sampling_alpha: float = 1.5  # (beta only) Beta distribution alpha
+    timestep_sampling_beta: float = 1.0  # (beta only) Beta distribution beta
+
+    # Transformer Architecture
+    hidden_dim: int = 512  # Transformer hidden dimension
+    num_layers: int = 6  # Number of transformer layers
+    num_heads: int = 8  # Number of attention heads
+    dropout: float = 0.1  # Dropout rate
+    use_positional_encoding: bool = False  # Use absolute positional encoding
+    timestep_embed_dim: int = 256  # Timestep embedding dimension
+    use_rope: bool = True  # Use Rotary Position Embedding
+    rope_base: float = 10000.0  # RoPE base frequency
+
+    # Vision Encoder (CLIP)
+    vision_encoder_name: str = "openai/clip-vit-base-patch16"  # HuggingFace CLIP model
+    use_separate_encoder_per_camera: bool = False  # Separate encoder per camera view
+    vision_encoder_lr_multiplier: float = 0.1  # LR multiplier for vision encoder
+    image_resize_shape: tuple[int, int] | None = None  # Resize images before crop
+    image_crop_shape: tuple[int, int] | None = (224, 224)  # Crop shape (CLIP default)
+    image_crop_is_random: bool = True  # Random crop during training, center at inference
+
+    # Text Encoder (CLIP)
+    text_encoder_name: str = "openai/clip-vit-base-patch16"  # HuggingFace CLIP model
+
+    # Normalization
     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
-            "VISUAL": NormalizationMode.MEAN_STD,  # Standard ImageNet normalization for vision
-            "STATE": NormalizationMode.MIN_MAX,  # [-1,1] range for proper diffusion clipping
-            "ACTION": NormalizationMode.MIN_MAX,  # [-1,1] range required for diffusion process
+            "VISUAL": NormalizationMode.MEAN_STD,
+            "STATE": NormalizationMode.MIN_MAX,
+            "ACTION": NormalizationMode.MIN_MAX,
         }
     )

-    drop_n_last_frames: int | None = None  # Auto-calculated: horizon - n_action_steps - n_obs_steps + 1
-    observation_encoder: ObservationEncoderConfig = field(default_factory=ObservationEncoderConfig)
-    transformer: TransformerConfig = field(default_factory=TransformerConfig)
-    objective: ObjectiveConfig = field(default_factory=DiffusionConfig)
-    do_mask_loss_for_padding: bool = False  # same logic as is implemented in LeRobot DP implementation
-
-    # training optimizer and scheduler hyperparameters
+    # Training/Optimizer
     optimizer_lr: float = 2e-5
     optimizer_betas: tuple = (0.95, 0.999)
     optimizer_eps: float = 1e-8
-    optimizer_weight_decay: float = 0.0  # No weight decay is suggested to be optimal
+    optimizer_weight_decay: float = 0.0
     scheduler_name: str = "cosine"
-    scheduler_warmup_steps: int = 0  # No warmup found to be optimal
+    scheduler_warmup_steps: int = 0
+    do_mask_loss_for_padding: bool = False
+
+    # Auto-calculated
+    drop_n_last_frames: int | None = None

     def __post_init__(self):
         super().__post_init__()
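Since the patch collapses three nested config objects into flat fields, the renames are easiest to review side by side. The mapping below is a crib sheet compiled from the hunks in this file and the call-site updates further down; the dict itself is illustrative, not part of the patch:

```python
# Old nested path (draccus ChoiceRegistry style) -> new flat MultiTaskDiTConfig field.
# Compiled from the config hunks above and the module diffs below.
NESTED_TO_FLAT = {
    "transformer.hidden_dim": "hidden_dim",
    "transformer.diffusion_step_embed_dim": "timestep_embed_dim",
    "observation_encoder.vision.model_name": "vision_encoder_name",
    "observation_encoder.vision.lr_multiplier": "vision_encoder_lr_multiplier",
    "observation_encoder.vision.resize_shape": "image_resize_shape",
    "observation_encoder.vision.crop_shape": "image_crop_shape",
    "observation_encoder.vision.crop_is_random": "image_crop_is_random",
    "observation_encoder.text.model": "text_encoder_name",
    "objective.timestep_sampling.s": "timestep_sampling_s",
    "objective.timestep_sampling.alpha": "timestep_sampling_alpha",
    "objective.timestep_sampling.beta": "timestep_sampling_beta",
}
```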
- """ + def _validate(self): + """Validate configuration parameters.""" + # Transformer validation + if self.hidden_dim <= 0: + raise ValueError("hidden_dim must be positive") + if self.num_layers <= 0: + raise ValueError("num_layers must be positive") + if self.num_heads <= 0: + raise ValueError("num_heads must be positive") + if self.hidden_dim % self.num_heads != 0: + raise ValueError("hidden_dim must be divisible by num_heads") + if not (0.0 <= self.dropout <= 1.0): + raise ValueError("dropout must be between 0.0 and 1.0") + + # Vision encoder validation + if "clip" not in self.vision_encoder_name.lower(): + raise ValueError( + f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'" + ) + if ( + self.image_resize_shape + and self.image_crop_shape + and ( + self.image_crop_shape[0] > self.image_resize_shape[0] + or self.image_crop_shape[1] > self.image_resize_shape[1] + ) + ): + raise ValueError( + f"image_crop_shape {self.image_crop_shape} must be <= image_resize_shape {self.image_resize_shape}" + ) + + # Text encoder validation + if "clip" not in self.text_encoder_name.lower(): + raise ValueError( + f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'" + ) + + # Objective-specific validation + if self.objective == "diffusion": + if self.noise_scheduler_type not in ["DDPM", "DDIM"]: + raise ValueError( + f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}" + ) + if self.prediction_type not in ["epsilon", "sample"]: + raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}") + if self.num_train_timesteps <= 0: + raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}") + if not (0.0 <= self.beta_start <= self.beta_end <= 1.0): + raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}") + + elif self.objective == "flow_matching": + if not (0.0 <= self.sigma_min <= 1.0): + raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}") + if self.num_integration_steps <= 0: + raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}") + if self.integration_method not in ["euler", "rk4"]: + raise ValueError( + f"integration_method must be 'euler' or 'rk4', got {self.integration_method}" + ) + if self.timestep_sampling_strategy not in ["uniform", "beta"]: + raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'") + if self.timestep_sampling_strategy == "beta": + if not (0.0 < self.timestep_sampling_s <= 1.0): + raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}") + if self.timestep_sampling_alpha <= 0: + raise ValueError("timestep_sampling_alpha must be positive") + if self.timestep_sampling_beta <= 0: + raise ValueError("timestep_sampling_beta must be positive") + + def get_optimizer_preset(self) -> AdamConfig: return AdamConfig( lr=self.optimizer_lr, betas=self.optimizer_betas, @@ -333,7 +188,6 @@ class MultiTaskDiTConfig(PreTrainedConfig): ) def get_scheduler_preset(self) -> DiffuserSchedulerConfig: - """Return learning rate scheduler configuration.""" return DiffuserSchedulerConfig( name=self.scheduler_name, num_warmup_steps=self.scheduler_warmup_steps, @@ -341,56 +195,41 @@ class MultiTaskDiTConfig(PreTrainedConfig): def validate_features(self) -> None: """Validate that required input features are present and properly configured.""" - # Robot state is always present via self.robot_state_feature, so we 
-        # Robot state is always present via self.robot_state_feature, so we don't need to enforce images/env_state
-        # This allows for testing and simple state-only policies
-
-        # Validate crop shape fits within image dimensions
-        crop_shape = self.observation_encoder.vision.crop_shape
-        if crop_shape is not None:
+        if self.image_crop_shape is not None:
             for key, image_ft in self.image_features.items():
-                if crop_shape[0] > image_ft.shape[1] or crop_shape[1] > image_ft.shape[2]:
+                if (
+                    self.image_crop_shape[0] > image_ft.shape[1]
+                    or self.image_crop_shape[1] > image_ft.shape[2]
+                ):
                     raise ValueError(
-                        f"`crop_shape` should fit within the images shapes. Got {crop_shape} "
-                        f"for `crop_shape` and {image_ft.shape} for "
-                        f"`{key}`."
+                        f"image_crop_shape {self.image_crop_shape} doesn't fit within image shape {image_ft.shape} "
+                        f"for '{key}'"
                     )

-        # Ensure all images have same shape (current limitation)
         if len(self.image_features) > 0:
-            first_image_key, first_image_ft = next(iter(self.image_features.items()))
+            first_key, first_ft = next(iter(self.image_features.items()))
             for key, image_ft in self.image_features.items():
-                if image_ft.shape != first_image_ft.shape:
+                if image_ft.shape != first_ft.shape:
                     raise ValueError(
-                        f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
+                        f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
                     )

-    @property
-    def model_objective(self) -> str:
-        return self.objective.objective_name
-
     @property
     def is_diffusion(self) -> bool:
-        return isinstance(self.objective, DiffusionConfig)
+        return self.objective == "diffusion"

     @property
     def is_flow_matching(self) -> bool:
-        return isinstance(self.objective, FlowMatchingConfig)
-
-    def get_objective_config(self) -> DiffusionConfig | FlowMatchingConfig:
-        """Get the objective-specific configuration with proper typing."""
-        return self.objective
+        return self.objective == "flow_matching"

     @property
     def observation_delta_indices(self) -> list:
-        """Delta indices for stacking observations. Provides temporal context."""
         return list(range(1 - self.n_obs_steps, 1))

     @property
     def action_delta_indices(self) -> list:
-        """Delta indices for action horizon prediction."""
         return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))

     @property
     def reward_delta_indices(self) -> None:
-        """Indices for reward deltas (not used in diffusion policy)."""
         return None
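The delta-index properties and the `drop_n_last_frames` default are easy to sanity-check by hand. A standalone sketch that mirrors the property bodies rather than importing the class (values match the small test configs used later in this patch):

```python
# Sanity check of the delta-index arithmetic in the config above.
n_obs_steps, horizon, n_action_steps = 2, 16, 8

observation_delta_indices = list(range(1 - n_obs_steps, 1))
action_delta_indices = list(range(1 - n_obs_steps, 1 - n_obs_steps + horizon))

assert observation_delta_indices == [-1, 0]  # current frame plus one step of history
assert action_delta_indices[0] == -1 and len(action_delta_indices) == horizon
# drop_n_last_frames default from __post_init__: horizon - n_action_steps - n_obs_steps + 1
assert horizon - n_action_steps - n_obs_steps + 1 == 7
```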
diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py
index e8b69c949..77adb5d73 100644
--- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py
+++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py
@@ -52,23 +52,22 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
         action_dim = config.action_feature.shape[0]
         horizon = config.horizon

-        self.model_objective = config.model_objective
         if config.is_diffusion:
             self.objective = DiffusionObjective(
-                config.get_objective_config(),
+                config,
                 action_dim=action_dim,
                 horizon=horizon,
                 do_mask_loss_for_padding=config.do_mask_loss_for_padding,
             )
         elif config.is_flow_matching:
             self.objective = FlowMatchingObjective(
-                config.get_objective_config(),
+                config,
                 action_dim=action_dim,
                 horizon=horizon,
                 do_mask_loss_for_padding=config.do_mask_loss_for_padding,
             )
         else:
-            raise ValueError(f"Unsupported model_objective: {self.model_objective}")
+            raise ValueError(f"Unsupported objective: {config.objective}")

         self.reset()

@@ -90,7 +89,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
             {"params": non_vision_params},
             {
                 "params": vision_encoder_params,
-                "lr": self.config.optimizer_lr * self.config.observation_encoder.vision.lr_multiplier,
+                "lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
             },
         ]

@@ -118,8 +117,8 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
         if self.config.env_state_feature:
             self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

-        if self.config.observation_encoder.text:
-            self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
+        # Always include task queue for text conditioning
+        self._queues["task"] = deque(maxlen=self.config.n_obs_steps)

     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
         """Run the batch through the model and compute the loss for training or validation."""
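The `get_optim_params` hunk above keeps the two-group optimizer layout, with the vision encoder trained at a reduced rate. A minimal sketch of the idea, with toy modules standing in for the policy (names here are illustrative):

```python
# Two param groups: default LR for the backbone, scaled LR for the vision encoder.
import torch
from torch import nn

backbone = nn.Linear(8, 8)        # stands in for the non-vision parameters
vision_encoder = nn.Linear(8, 8)  # stands in for the CLIP vision encoder

optimizer_lr, vision_encoder_lr_multiplier = 2e-5, 0.1
param_groups = [
    {"params": backbone.parameters()},
    {"params": vision_encoder.parameters(), "lr": optimizer_lr * vision_encoder_lr_multiplier},
]
optimizer = torch.optim.Adam(param_groups, lr=optimizer_lr, betas=(0.95, 0.999), eps=1e-8)

assert optimizer.param_groups[0]["lr"] == optimizer_lr  # default group
assert optimizer.param_groups[1]["lr"] == optimizer_lr * vision_encoder_lr_multiplier  # ~2e-6
```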
- """ + """Base class for objectives used in Multi-Task DiT policy.""" def __init__(self, config, action_dim: int, horizon: int): self.config = config @@ -47,38 +40,17 @@ class BaseObjective(ABC): @abstractmethod def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: - """Compute training loss for the objective. - - Args: - model: The prediction network - batch: Training batch with observations and actions - conditioning_vec: Encoded observation features for conditioning - - Returns: - Scalar loss tensor - """ + """Compute training loss.""" pass @abstractmethod def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: - """Generate action samples conditioned on embedded observation features. - - Args: - model: The prediction network - batch_size: The number of samples to generate - conditioning_vec: Encoded observation features for conditioning - - Returns: - Generated action sequences (batch_size, horizon, action_dim) - """ + """Generate action samples conditioned on observations.""" pass class DiffusionObjective(BaseObjective): - """Standard diffusion (DDPM/DDIM) objective implementation. - - Contains the noise scheduler, training loss, and conditional sampling. - """ + """Standard diffusion (DDPM/DDIM) objective implementation.""" def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): super().__init__(config, action_dim, horizon) @@ -90,8 +62,8 @@ class DiffusionObjective(BaseObjective): "beta_start": config.beta_start, "beta_end": config.beta_end, "beta_schedule": config.beta_schedule, - "clip_sample": getattr(config, "clip_sample", True), - "clip_sample_range": getattr(config, "clip_sample_range", 1.0), + "clip_sample": config.clip_sample, + "clip_sample_range": config.clip_sample_range, "prediction_type": config.prediction_type, } @@ -105,7 +77,7 @@ class DiffusionObjective(BaseObjective): # Inference steps default to training steps if not provided self.num_inference_steps = ( config.num_inference_steps - if getattr(config, "num_inference_steps", None) is not None + if config.num_inference_steps is not None else self.noise_scheduler.config.num_train_timesteps ) @@ -161,40 +133,29 @@ class DiffusionObjective(BaseObjective): class FlowMatchingObjective(BaseObjective): - """ - Flow matching objective: trains a model to predict velocity fields v_θ(x, t) that transports - noise to data. This basically interpolates as path between noise and a trained distribution - with a set target velocity. - """ + """Flow matching objective: trains a model to predict velocity fields.""" def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): super().__init__(config, action_dim, horizon) self.do_mask_loss_for_padding = do_mask_loss_for_padding def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor: - """Sample timesteps according to configured strategy. 
@@ -161,40 +133,29 @@
 class FlowMatchingObjective(BaseObjective):
-    """
-    Flow matching objective: trains a model to predict velocity fields v_θ(x, t) that transports
-    noise to data. This basically interpolates as path between noise and a trained distribution
-    with a set target velocity.
-    """
+    """Flow matching objective: trains a model to predict velocity fields."""

     def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
         super().__init__(config, action_dim, horizon)
         self.do_mask_loss_for_padding = do_mask_loss_for_padding

     def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
-        """Sample timesteps according to configured strategy.
-
-        Uniform: Sample t uniformly from [0,1]
-        Beta: Sample t from Beta(α,β) scaled to [0,s], emphasizing high noise (low t)
-        """
-        if self.config.timestep_sampling.strategy_name == "uniform":
+        """Sample timesteps according to configured strategy."""
+        if self.config.timestep_sampling_strategy == "uniform":
             return torch.rand(batch_size, device=device)
-        elif self.config.timestep_sampling.strategy_name == "beta":
+        elif self.config.timestep_sampling_strategy == "beta":
             # Sample u ~ Beta(α, β) then transform: t = s(1-u)
             # This emphasizes t near 0 (high noise) when α > β
             beta_dist = torch.distributions.Beta(
-                self.config.timestep_sampling.alpha, self.config.timestep_sampling.beta
+                self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
             )
             u = beta_dist.sample((batch_size,)).to(device)
-            return self.config.timestep_sampling.s * (1.0 - u)
+            return self.config.timestep_sampling_s * (1.0 - u)
         else:
-            raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling.strategy_name}")
+            raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")

     def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
-        """Compute flow matching training loss.
-
-        Trains the model to predict the velocity field along linear interpolation paths.
-        """
+        """Compute flow matching training loss."""
         data = batch["action"]  # Clean action sequences (B, T, D)
         batch_size = data.shape[0]
         device = data.device
@@ -217,10 +178,7 @@
         return loss.mean()

     def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
-        """Generate actions by integrating the learned velocity field via ODE.
-
-        Solves: dx/dt = v_θ(x,t) from t=0 (noise) to t=1 (data)
-        """
+        """Generate actions by integrating the learned velocity field via ODE."""
         device = next(model.parameters()).device
         dtype = next(model.parameters()).dtype

@@ -244,9 +202,7 @@
     def _euler_integrate(
         self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
     ) -> Tensor:
-        """
-        Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)
-        """
+        """Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)"""
         x = x_init

         for i in range(len(time_grid) - 1):
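The Euler loop body is elided by the hunk context above. A self-contained sketch of the update it implements, with `velocity_fn` as a hypothetical stand-in for the conditioned model call:

```python
# Euler integration of dx/dt = v(x, t) from t=0 (noise) to t=1 (data):
#   x_{n+1} = x_n + dt * v(x_n, t_n)
import torch

def euler_integrate(velocity_fn, x_init: torch.Tensor, num_steps: int) -> torch.Tensor:
    time_grid = torch.linspace(0.0, 1.0, num_steps + 1)
    x = x_init
    for i in range(len(time_grid) - 1):
        t = time_grid[i].expand(x.shape[0])  # broadcast scalar t over the batch
        dt = time_grid[i + 1] - time_grid[i]
        x = x + dt * velocity_fn(x, t)
    return x

# Toy check: a constant velocity field v(x, t) = 1 moves x by exactly 1 over [0, 1].
out = euler_integrate(lambda x, t: torch.ones_like(x), torch.zeros(2, 3), num_steps=10)
assert torch.allclose(out, torch.ones(2, 3))
```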
@@ -268,23 +224,10 @@
     def _rk4_integrate(
         self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
     ) -> Tensor:
-        """4th-order Runge-Kutta integration.
-
-        Uses 4 velocity evaluations per step:
-        k1 = v(x, t)
-        k2 = v(x + dt·k1/2, t + dt/2)
-        k3 = v(x + dt·k2/2, t + dt/2)
-        k4 = v(x + dt·k3, t + dt)
-        x_next = x + dt/6·(k1 + 2k2 + 2k3 + k4)
-
-        4x slower than Euler but more accurate.
-
-        Note: In practice, this seems to not matter much
-        """
+        """4th-order Runge-Kutta integration."""
         x = x_init

         def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
-            """dynamics helper to get velocity at (x, t)"""
             t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
             with torch.no_grad():
                 return model(x_val, t_batch, conditioning_vec=conditioning_vec)
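The beta sampling transform `t = s * (1 - u)` used in `_sample_timesteps` can be checked numerically; with the defaults `alpha=1.5, beta=1.0`, the mass concentrates near `t = 0` (high noise). A small sketch:

```python
# Beta timestep sampling: u ~ Beta(alpha, beta), then t = s * (1 - u).
import torch

s, alpha, beta = 0.999, 1.5, 1.0
u = torch.distributions.Beta(alpha, beta).sample((10_000,))
t = s * (1.0 - u)

assert t.min() >= 0.0 and t.max() <= s
# E[u] = alpha / (alpha + beta) = 0.6, so E[t] = s * 0.4 ~ 0.4 (loose statistical check).
assert abs(t.mean().item() - s * (1 - alpha / (alpha + beta))) < 0.02
```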
diff --git a/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py b/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py
index 794ab71f1..b2f353282 100644
--- a/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py
+++ b/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py
@@ -68,9 +68,7 @@ class CLIPVisionEncoder(nn.Module):


 class CLIPTextEncoder(nn.Module):
-    """Supports any HuggingFace CLIP model. The encoder weights are frozen,
-    and a learnable projection layer maps the CLIP embeddings to the desired dimension.
-    """
+    """CLIP text encoder with frozen weights and a learnable projection layer."""

     def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
         super().__init__()
@@ -126,21 +124,20 @@ class ObservationEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        vision_config = config.observation_encoder.vision
-        self._setup_preprocessing(vision_config)
+        self._setup_preprocessing(config)

         if config.image_features:
             self.num_cameras = len(config.image_features)
-            self.camera_names = list(config.image_features.keys())  # Preserve ordering
+            self.camera_names = list(config.image_features.keys())

-            if vision_config.use_separate_encoder_per_camera:
+            if config.use_separate_encoder_per_camera:
                 self.vision_encoders = nn.ModuleList(
-                    [CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
+                    [CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
                 )
                 self.vision_encoder = None
             else:
-                self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
+                self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
                 self.vision_encoders = None
         else:
             self.vision_encoder = None
@@ -158,9 +155,8 @@
         else:
             self.env_state_dim = 0

-        text_config = config.observation_encoder.text
-        self.text_dim = config.transformer.hidden_dim
-        self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
+        self.text_dim = config.hidden_dim
+        self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)

         self._setup_vector_output()

@@ -173,22 +169,23 @@

         return images

-    def _setup_preprocessing(self, vision_config):
+    def _setup_preprocessing(self, config):
         """Setup image preprocessing transforms."""
-        if vision_config.resize_shape is not None:
+        if config.image_resize_shape is not None:
             self.do_resize = True
             self.resize = torchvision.transforms.Resize(
-                size=vision_config.resize_shape,
+                size=config.image_resize_shape,
                 interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                 antialias=True,
             )
         else:
             self.do_resize = False
-        if vision_config.crop_shape is not None:
+
+        if config.image_crop_shape is not None:
             self.do_crop = True
-            self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
-            if vision_config.crop_is_random:
-                self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
+            self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
+            if config.image_crop_is_random:
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
             else:
                 self.maybe_random_crop = self.center_crop
         else:
@@ -199,7 +196,7 @@

         # Vision features - get CLS token feature dimension
         if self.vision_encoder is not None or self.vision_encoders is not None:
-            encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders.values()))
+            encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))

             # Get output shape from encoder (deterministic for CLS tokens)
             feature_map_shape = encoder_to_check.get_output_shape()
@@ -233,8 +230,7 @@
         # Shape is (B, N, C, H, W) - add time dimension
         images = images.unsqueeze(1)  # (B, 1, N, C, H, W)

-        vision_config = self.config.observation_encoder.vision
-        if vision_config.use_separate_encoder_per_camera:
+        if self.config.use_separate_encoder_per_camera:
             # Process each camera with its own encoder
             camera_features = []
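A standalone sketch of the resize-then-crop pipeline that `_setup_preprocessing` builds (training uses the random crop when `image_crop_is_random=True`, inference the center crop); the shapes here are illustrative:

```python
# Resize -> crop pipeline mirroring the transforms instantiated above.
import torch
import torchvision

image_resize_shape, image_crop_shape = (256, 256), (224, 224)

resize = torchvision.transforms.Resize(
    size=image_resize_shape,
    interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    antialias=True,
)
center_crop = torchvision.transforms.CenterCrop(image_crop_shape)
random_crop = torchvision.transforms.RandomCrop(image_crop_shape)

img = torch.rand(3, 480, 640)
train_img = random_crop(resize(img))  # training path
eval_img = center_crop(resize(img))   # inference path
assert train_img.shape == eval_img.shape == (3, 224, 224)
```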
diff --git a/src/lerobot/policies/multi_task_dit/modules/transformer.py b/src/lerobot/policies/multi_task_dit/modules/transformer.py
index 7631d07d6..fca1f9022 100644
--- a/src/lerobot/policies/multi_task_dit/modules/transformer.py
+++ b/src/lerobot/policies/multi_task_dit/modules/transformer.py
@@ -308,32 +308,29 @@ class TransformerBlock(nn.Module):


 class DiffusionTransformer(nn.Module):
-    """
-    Transformer-based diffusion noise prediction model.
-    """
+    """Transformer-based diffusion noise prediction model."""

     def __init__(self, config, conditioning_dim: int):
         """Initialize transformer for noise prediction.

         Args:
-            config: Multi-Task DiTConfig with transformer parameters
+            config: MultiTaskDiTConfig with transformer parameters
             conditioning_dim: Dimension of concatenated observation features
         """
         super().__init__()
         self.config = config
-        self.transformer_config = config.transformer
         self.conditioning_dim = conditioning_dim

         self.action_dim = config.action_feature.shape[0]
         self.horizon = config.horizon

-        self.hidden_size = self.transformer_config.hidden_dim
-        self.num_layers = self.transformer_config.num_layers
-        self.num_heads = self.transformer_config.num_heads
-        self.dropout = self.transformer_config.dropout
-        self.use_rope = self.transformer_config.use_rope
+        self.hidden_size = config.hidden_dim
+        self.num_layers = config.num_layers
+        self.num_heads = config.num_heads
+        self.dropout = config.dropout
+        self.use_rope = config.use_rope

-        self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
+        self.timestep_embed_dim = config.timestep_embed_dim
         self.time_mlp = nn.Sequential(
             SinusoidalPosEmb(self.timestep_embed_dim),
             nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
@@ -347,7 +344,7 @@
         # Project action dimensions to hidden size
         self.input_proj = nn.Linear(self.action_dim, self.hidden_size)

-        if self.transformer_config.use_positional_encoding:
+        if config.use_positional_encoding:
             # Learnable positional embeddings for sequence positions (absolute encoding)
             self.pos_embedding = nn.Parameter(
                 torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
@@ -363,8 +360,8 @@
                     num_features=self.cond_dim,
                     dropout=self.dropout,
                     use_rope=self.use_rope,
-                    max_seq_len=self.horizon,  # This remains fixed because we aren't generating variable length sequences
-                    rope_base=getattr(self.transformer_config, "rope_base", 10000.0),
+                    max_seq_len=self.horizon,
+                    rope_base=config.rope_base,
                 )
                 for _ in range(self.num_layers)
             ]
@@ -377,9 +374,7 @@
         self._initialize_weights()

     def _initialize_weights(self):
-        """
-        Zero-initializing the final linear layer of adaLN_modulation in each block improves training stability
-        """
+        """Zero-initialize final linear layer of adaLN_modulation for training stability."""
         for block in self.transformer_blocks:
             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
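`SinusoidalPosEmb` itself is untouched by this patch; for readers unfamiliar with it, a common implementation of that timestep embedding looks like the following (an assumption about the helper, not code from this repo): half the channels get sine, half cosine, at log-spaced frequencies.

```python
# Common sinusoidal timestep embedding (illustrative; the repo's SinusoidalPosEmb
# may differ in detail).
import math
import torch

def sinusoidal_pos_emb(t: torch.Tensor, dim: int) -> torch.Tensor:
    half_dim = dim // 2
    freqs = torch.exp(torch.arange(half_dim) * (-math.log(10000.0) / (half_dim - 1)))
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([args.sin(), args.cos()], dim=-1)

emb = sinusoidal_pos_emb(torch.arange(4), dim=256)  # matches timestep_embed_dim=256
assert emb.shape == (4, 256)
```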
diff --git a/tests/policies/multi_task_dit/test_multi_task_dit.py b/tests/policies/multi_task_dit/test_multi_task_dit.py
index 575a7a9ca..88bdd6e24 100644
--- a/tests/policies/multi_task_dit/test_multi_task_dit.py
+++ b/tests/policies/multi_task_dit/test_multi_task_dit.py
@@ -28,11 +28,7 @@
 import torch
 from torch import Tensor

 from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
-    DiffusionConfig,
-    FlowMatchingConfig,
-    MultiTaskDiTConfig,
-)
+from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
 from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
 from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
 from lerobot.utils.random_utils import seeded_context, set_seed
@@ -108,13 +104,12 @@ def create_config(
         n_obs_steps=n_obs_steps,
         horizon=horizon,
         n_action_steps=n_action_steps,
+        # Use smaller model for faster tests
+        hidden_dim=128,
+        num_layers=2,
+        num_heads=4,
     )

-    # Use smaller model for faster tests
-    config.transformer.hidden_dim = 128
-    config.transformer.num_layers = 2
-    config.transformer.num_heads = 4
-
     config.validate_features()
     return config

@@ -189,18 +184,28 @@ def test_multi_task_dit_policy_diffusion_objective():
     horizon = 16
     n_action_steps = 8

-    config = create_config(
-        state_dim=state_dim,
-        action_dim=action_dim,
+    input_features = {
+        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
+        f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
+    }
+
+    config = MultiTaskDiTConfig(
+        input_features=input_features,
+        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
         n_obs_steps=n_obs_steps,
         horizon=horizon,
         n_action_steps=n_action_steps,
-    )
-    config.objective = DiffusionConfig(
+        # Use diffusion objective
+        objective="diffusion",
         noise_scheduler_type="DDPM",
         num_train_timesteps=100,
         num_inference_steps=10,
+        # Smaller model for tests
+        hidden_dim=128,
+        num_layers=2,
+        num_heads=4,
     )
+
     config.validate_features()

     policy = MultiTaskDiTPolicy(config=config)
     policy.train()
@@ -235,18 +240,28 @@ def test_multi_task_dit_policy_flow_matching_objective():
     horizon = 16
     n_action_steps = 8

-    config = create_config(
-        state_dim=state_dim,
-        action_dim=action_dim,
+    input_features = {
+        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
+        f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
+    }
+
+    config = MultiTaskDiTConfig(
+        input_features=input_features,
+        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
         n_obs_steps=n_obs_steps,
         horizon=horizon,
         n_action_steps=n_action_steps,
-    )
-    config.objective = FlowMatchingConfig(
+        # Use flow matching objective
+        objective="flow_matching",
         sigma_min=0.0,
-        num_integration_steps=10,  # Use fewer steps for faster tests
+        num_integration_steps=10,  # Fewer steps for faster tests
         integration_method="euler",
+        # Smaller model for tests
+        hidden_dim=128,
+        num_layers=2,
+        num_heads=4,
     )
+
     config.validate_features()

     policy = MultiTaskDiTPolicy(config=config)
     policy.train()
@@ -373,5 +388,5 @@ def test_multi_task_dit_policy_get_optim_params():
     # Second group is vision encoder params with different lr
     assert "params" in param_groups[1]
     assert "lr" in param_groups[1]
-    expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier
+    expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
     assert param_groups[1]["lr"] == expected_lr
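Putting the pieces together, the flat config is now built in a single call. A usage sketch distilled from the tests above (feature shapes and the camera key are illustrative; the smaller `hidden_dim`/`num_layers`/`num_heads` used in the tests can be passed the same way):

```python
# End-to-end construction of the flattened config and policy.
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE

config = MultiTaskDiTConfig(
    input_features={
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
        f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    },
    output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,))},
    objective="flow_matching",  # or "diffusion"
    num_integration_steps=10,
)
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
```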