diff --git a/pyproject.toml b/pyproject.toml index e953820ff..7e22285e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,7 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"] # Policies pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] +multi_task_dit = ["lerobot[transformers-dep]", "timm>=1.0.20"] groot = [ "lerobot[transformers-dep]", "peft>=0.13.0,<1.0.0", diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py index eec574296..63d6a44f4 100644 --- a/src/lerobot/__init__.py +++ b/src/lerobot/__init__.py @@ -157,7 +157,7 @@ available_datasets = sorted( ) # lists all available policies from `lerobot/policies` -available_policies = ["act", "diffusion", "tdmpc", "vqbet"] +available_policies = ["act", "multi_task_dit", "diffusion", "tdmpc", "vqbet"] # lists all available robots from `lerobot/robots` available_robots = [ diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 4cdc89ea9..c8240e3a4 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -15,6 +15,7 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .groot.configuration_groot import GrootConfig as GrootConfig +from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi05.configuration_pi05 import PI05Config as PI05Config from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig @@ -25,6 +26,7 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig __all__ = [ "ACTConfig", "DiffusionConfig", + "MultiTaskDiTConfig", "PI0Config", "PI05Config", "SmolVLAConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index eb6266757..65682a98f 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -31,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.groot.configuration_groot import GrootConfig +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy @@ -59,7 +60,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla". + "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla". Returns: The policy class corresponding to the given name. 
@@ -79,6 +80,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.act.modeling_act import ACTPolicy
 
         return ACTPolicy
+    elif name == "multi_task_dit":
+        from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
+
+        return MultiTaskDiTPolicy
     elif name == "vqbet":
         from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
 
@@ -120,8 +125,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
 
     Args:
         policy_type: The type of the policy. Supported types include "tdmpc",
-            "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
-            "reward_classifier".
+            "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
+            "reward_classifier", "smolvla".
         **kwargs: Keyword arguments to be passed to the configuration class constructor.
 
     Returns:
@@ -136,6 +141,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return DiffusionConfig(**kwargs)
     elif policy_type == "act":
         return ACTConfig(**kwargs)
+    elif policy_type == "multi_task_dit":
+        return MultiTaskDiTConfig(**kwargs)
     elif policy_type == "vqbet":
         return VQBeTConfig(**kwargs)
     elif policy_type == "pi0":
@@ -274,6 +281,16 @@ def make_pre_post_processors(
             dataset_stats=kwargs.get("dataset_stats"),
         )
 
+    elif isinstance(policy_cfg, MultiTaskDiTConfig):
+        from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
+            make_multi_task_dit_pre_post_processors,
+        )
+
+        processors = make_multi_task_dit_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+        )
+
     elif isinstance(policy_cfg, VQBeTConfig):
         from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
 
diff --git a/src/lerobot/policies/multi_task_dit/README.md b/src/lerobot/policies/multi_task_dit/README.md
new file mode 100644
index 000000000..293da4e54
--- /dev/null
+++ b/src/lerobot/policies/multi_task_dit/README.md
@@ -0,0 +1,45 @@
+# Multi-Task DiT Policy
+
+For details describing the architecture, see the citations and the blog post from Bryson Jones: https://brysonkjones.substack.com/p/dissecting-multi-task-diffusion-transformer-policy
+
+## Training and Inference Baseline Recommendations
+
+### Training
+
+- Number of demonstrations: >100 per task
+- Batch Size: 320
+- Objective: Diffusion
+- Cameras: At least two, with one egocentric view per arm
+
+### Inference
+
+- GPU: NVIDIA RTX 5070 Ti or faster
+- Sampling:
+  - Strategy: DDIM
+  - Number of Timesteps: 10
+
+## Citation
+
+If you use this work, please cite the following:
+
+```bibtex
+@misc{jones2025multitaskditpolicy,
+  author = {Bryson Jones},
+  title = {Dissecting Multitask Diffusion Transformer Policy},
+  year = {2025},
+  url = {https://brysonkjones.substack.com/p/dissecting-multitask-diffusion-transformer-policy},
+  note = {Blog post}
+}
+```
+
+```bibtex
+@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
+  author = {TRI LBM Team},
+  title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
+  year = {2025},
+  eprint = {arXiv:2507.05331},
+  archivePrefix = {arXiv},
+  primaryClass = {cs.RO},
+  url = {https://arxiv.org/abs/2507.05331}
+}
+```
diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
new file mode 100644
index 000000000..5a6a71f37
--- /dev/null
+++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
@@ -0,0 +1,433 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Bryson Jones
and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +import draccus + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode +from lerobot.optim.optimizers import AdamConfig +from lerobot.optim.schedulers import DiffuserSchedulerConfig + + +@dataclass +class ObjectiveConfig(draccus.ChoiceRegistry): + """Base configuration for model objectives (diffusion, flow matching, etc.).""" + + pass + + +@ObjectiveConfig.register_subclass("diffusion") +@dataclass +class DiffusionConfig(ObjectiveConfig): + """Configuration for standard diffusion model training and inference. + + These parameters control the noise scheduling and denoising process for + standard DDPM/DDIM diffusion models. + """ + + objective_name: str = field(default="diffusion", init=False) + + # Noise scheduler configuration - controls diffusion process + noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM" + num_train_timesteps: int = 100 # 100 noise levels for fine-grained control + beta_schedule: str = "squaredcos_cap_v2" # Cosine schedule prevents extreme noise + beta_start: float = 0.0001 # Small initial noise level + beta_end: float = 0.02 # Moderate final noise level + prediction_type: str = "epsilon" # Predict noise (works better than direct prediction) + clip_sample: bool = True # Prevent extreme action values + clip_sample_range: float = 1.0 # Clip to [-1, 1] range + + # Inference configuration + num_inference_steps: int | None = None # Default to num_train_timesteps + + def __post_init__(self): + """Validate diffusion-specific parameters.""" + if self.noise_scheduler_type not in ["DDPM", "DDIM"]: + raise ValueError( + f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}" + ) + + if self.prediction_type not in ["epsilon", "sample"]: + raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}") + + if self.num_train_timesteps <= 0: + raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}") + + if not (0.0 <= self.beta_start <= self.beta_end <= 1.0): + raise ValueError( + "beta values must satisfy 0 <= beta_start <= beta_end <= 1, " + f"got {self.beta_start}, {self.beta_end}" + ) + + +@dataclass +class TimestepSamplingConfig(draccus.ChoiceRegistry): + """Base configuration for timestep sampling strategies during training.""" + + pass + + +@TimestepSamplingConfig.register_subclass("uniform") +@dataclass +class UniformTimestepSamplingConfig(TimestepSamplingConfig): + """Uniform timestep sampling from [0, 1].""" + + strategy_name: str = field(default="uniform", init=False) + + +@TimestepSamplingConfig.register_subclass("beta") +@dataclass +class BetaTimestepSamplingConfig(TimestepSamplingConfig): + """Beta distribution timestep sampling. + + Samples from Beta distribution emphasizing low timesteps (high noise). 
+ + This was inspired on the work from Physical Intelligence PI-0 model, + where they suggested the beta distribution for sampling timesteps + during training improved sample quality. + """ + + strategy_name: str = field(default="beta", init=False) + + s: float = 0.999 # Max timestep threshold for beta sampling + alpha: float = 1.5 # Beta distribution alpha parameter + beta: float = 1.0 # Beta distribution beta parameter + + def __post_init__(self): + if not (0.0 < self.s <= 1.0): + raise ValueError(f"s must be in (0, 1], got {self.s}") + + if self.alpha <= 0: + raise ValueError(f"alpha must be positive, got {self.alpha}") + + if self.beta <= 0: + raise ValueError(f"beta must be positive, got {self.beta}") + + +@ObjectiveConfig.register_subclass("flow_matching") +@dataclass +class FlowMatchingConfig(ObjectiveConfig): + """Configuration for flow matching training and inference. + + These parameters control the velocity field learning and ODE integration + process for flow matching models. + """ + + objective_name: str = field(default="flow_matching", init=False) + + # Flow path construction + sigma_min: float = 0.0 # Minimum noise level in flow interpolation path + + # ODE integration for inference + num_integration_steps: int = ( + 100 # Number of ODE integration steps (increased from 50 for smoother trajectories) + ) + integration_method: str = "euler" # ODE solver: "euler" or "rk4" + + # Timestep sampling strategy for training + # Beta distribution found to be the most effective in practice, so it is the default + timestep_sampling: TimestepSamplingConfig = field(default_factory=BetaTimestepSamplingConfig) + + def __post_init__(self): + if not (0.0 <= self.sigma_min <= 1.0): + raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}") + + if self.num_integration_steps <= 0: + raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}") + + if self.integration_method not in ["euler", "rk4"]: + raise ValueError(f"integration_method must be 'euler' or 'rk4', got {self.integration_method}") + + +@dataclass +class TransformerConfig: + """Configuration for Transformer-based prediction model. + + These parameters control the transformer architecture used for noise/velocity + prediction in diffusion and flow matching models. + """ + + # Transformer architecture parameters + hidden_dim: int = 512 # Hidden dimension of transformer + num_layers: int = 6 # Number of transformer layers + num_heads: int = 8 # Number of attention heads + dropout: float = 0.1 # Dropout rate + use_positional_encoding: bool = True # Whether to use positional encoding + diffusion_step_embed_dim: int = 256 # Timestep embedding size + + def __post_init__(self): + """Validate Transformer-specific parameters.""" + if self.hidden_dim <= 0: + raise ValueError("hidden_dim must be positive") + + if self.num_layers <= 0: + raise ValueError("num_layers must be positive") + + if self.num_heads <= 0: + raise ValueError("num_heads must be positive") + + if self.hidden_dim % self.num_heads != 0: + raise ValueError("hidden_dim must be divisible by num_heads") + + if not (0.0 <= self.dropout <= 1.0): + raise ValueError("dropout must be between 0.0 and 1.0") + + if self.diffusion_step_embed_dim <= 0: + raise ValueError("diffusion_step_embed_dim must be positive") + + +@dataclass +class VisionEncoderConfig(draccus.ChoiceRegistry): + """Base configuration for vision encoders. + + All image preprocessing is centralized here: + 1. Resize (optional) - resize images to target resolution + 2. 
Crop (optional) - crop after resize, must be smaller than resize_shape + 3. Random crop - whether to use random cropping during training + """ + + use_separate_encoder_per_camera: bool = False # Common parameters across all vision encoders + + # Learning rate multiplier for vision encoder parameters + # Vision encoder learning rate = optimizer_lr * lr_multiplier + lr_multiplier: float = 0.1 + + # Image preprocessing (centralized) + resize_shape: tuple[int, int] | None = None + crop_shape: tuple[int, int] | None = (224, 224) # default input size for CLIP + crop_is_random: bool = True + + def __post_init__(self): + if ( + self.resize_shape + and self.crop_shape + and (self.crop_shape[0] > self.resize_shape[0] or self.crop_shape[1] > self.resize_shape[1]) + ): + raise ValueError( + f"crop_shape {self.crop_shape} must be smaller than or equal to " + f"resize_shape {self.resize_shape}. Got crop={self.crop_shape}, resize={self.resize_shape}" + ) + + +@VisionEncoderConfig.register_subclass("dinov3") +@dataclass +class DinoV3EncoderConfig(VisionEncoderConfig): + """DinoV3 vision encoder configuration. + + DinoV3 is a self-supervised Vision Transformer trained by Meta. + CLS token usage and spatial feature extraction are handled automatically. + + Available backbones: + - vit_base_patch16_dinov3.lvd1689m (768 dims) + """ + + backbone: str = "vit_base_patch16_dinov3.lvd1689m" + + def __post_init__(self): + super().__post_init__() + # Validate backbone name + valid_backbones = [ + "vit_base_patch16_dinov3.lvd1689m", + ] + if self.backbone not in valid_backbones: + raise ValueError(f"backbone must be one of {valid_backbones}, got '{self.backbone}'") + + +@VisionEncoderConfig.register_subclass("clip") +@dataclass +class CLIPVisionEncoderConfig(VisionEncoderConfig): + """CLIP vision encoder configuration. + + CLIP is a vision-language model trained by OpenAI. + CLS token usage is handled automatically. + CLIP's internal preprocessing (resize to 224x224) can be overridden + by setting resize_shape and crop_shape. + + Available backbones: + - vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224) + """ + + backbone: str = "vit_base_patch16_clip_224.openai" + + def __post_init__(self): + super().__post_init__() + # Validate backbone name + if "clip" not in self.backbone.lower(): + raise ValueError(f"backbone must be a CLIP model, got '{self.backbone}'") + + +@dataclass +class TextEncoderConfig(draccus.ChoiceRegistry): + """Base configuration for text encoders. + + If a text encoder is set in ObservationEncoderConfig, text conditioning + is automatically enabled. + """ + + pass + + def __post_init__(self): + pass + + +@TextEncoderConfig.register_subclass("clip") +@dataclass +class CLIPTextEncoderConfig(TextEncoderConfig): + """CLIP text encoder for task conditioning. + + Uses CLIP's text encoder to embed task descriptions, which are then + used to condition the policy. The text embeddings are processed by + a learnable projection layer before being concatenated into the + conditioning vector. + """ + + model: str = "openai/clip-vit-base-patch16" + + def __post_init__(self): + super().__post_init__() + if "clip" not in self.model.lower(): + raise ValueError(f"CLIP text encoder requires a CLIP model. Got '{self.model}'") + + +@dataclass +class ObservationEncoderConfig: + """Top-level configuration for observation encoding. 
+ + This config combines: + - Vision encoding (required): DinoV3 or CLIP vision encoder + """ + + vision: VisionEncoderConfig = field(default_factory=CLIPVisionEncoderConfig) + text: TextEncoderConfig = field(default_factory=CLIPTextEncoderConfig) + + +@PreTrainedConfig.register_subclass("multi_task_dit") +@dataclass +class MultiTaskDiTConfig(PreTrainedConfig): + """ + Configuration class for the Multi-Task Diffusion Transformer (DiT) policy. + """ + + # Temporal structure - controls how the policy processes time and predicts actions + n_obs_steps: int = 2 # num observations for temporal context (..., t-1, t) + horizon: int = 100 # predicted action steps into the future + n_action_steps: int = 24 # actions per policy call (receding horizon) -- ~0.8s is a good place to start + + # Normalization strategy - critical for diffusion model performance + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, # Standard ImageNet normalization for vision + "STATE": NormalizationMode.MIN_MAX, # [-1,1] range for proper diffusion clipping + "ACTION": NormalizationMode.MIN_MAX, # [-1,1] range required for diffusion process + } + ) + + drop_n_last_frames: int | None = None # Auto-calculated: horizon - n_action_steps - n_obs_steps + 1 + observation_encoder: ObservationEncoderConfig = field(default_factory=ObservationEncoderConfig) + transformer: TransformerConfig = field(default_factory=TransformerConfig) + objective: ObjectiveConfig = field(default_factory=DiffusionConfig) + do_mask_loss_for_padding: bool = False # same logic as is implemented in LeRobot DP implementation + + # training optimizer and scheduler hyperparameters + optimizer_lr: float = 2e-5 + optimizer_betas: tuple = (0.95, 0.999) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 0.0 # No weight decay is suggested to be optimal + scheduler_name: str = "cosine" + scheduler_warmup_steps: int = 0 # No warmup found to be optimal + + def __post_init__(self): + super().__post_init__() + + if self.drop_n_last_frames is None: + self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1 + + def get_optimizer_preset(self) -> AdamConfig: + """Return Adam optimizer configuration optimized for diffusion training. + + Note: Vision encoder learning rate is set separately via get_optim_params. + """ + return AdamConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> DiffuserSchedulerConfig: + """Return learning rate scheduler configuration.""" + return DiffuserSchedulerConfig( + name=self.scheduler_name, + num_warmup_steps=self.scheduler_warmup_steps, + ) + + def validate_features(self) -> None: + """Validate that required input features are present and properly configured.""" + # Robot state is always present via self.robot_state_feature, so we don't need to enforce images/env_state + # This allows for testing and simple state-only policies + + # Validate crop shape fits within image dimensions + crop_shape = self.observation_encoder.vision.crop_shape + if crop_shape is not None: + for key, image_ft in self.image_features.items(): + if crop_shape[0] > image_ft.shape[1] or crop_shape[1] > image_ft.shape[2]: + raise ValueError( + f"`crop_shape` should fit within the images shapes. Got {crop_shape} " + f"for `crop_shape` and {image_ft.shape} for " + f"`{key}`." 
+ ) + + # Ensure all images have same shape (current limitation) + if len(self.image_features) > 0: + first_image_key, first_image_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): + if image_ft.shape != first_image_ft.shape: + raise ValueError( + f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." + ) + + @property + def model_objective(self) -> str: + return self.objective.objective_name + + @property + def is_diffusion(self) -> bool: + return isinstance(self.objective, DiffusionConfig) + + @property + def is_flow_matching(self) -> bool: + return isinstance(self.objective, FlowMatchingConfig) + + def get_objective_config(self) -> DiffusionConfig | FlowMatchingConfig: + """Get the objective-specific configuration with proper typing.""" + return self.objective + + @property + def observation_delta_indices(self) -> list: + """Delta indices for stacking observations. Provides temporal context.""" + return list(range(1 - self.n_obs_steps, 1)) + + @property + def action_delta_indices(self) -> list: + """Delta indices for action horizon prediction.""" + return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon)) + + @property + def reward_delta_indices(self) -> None: + """Indices for reward deltas (not used in diffusion policy).""" + return None diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py new file mode 100644 index 000000000..e8b69c949 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-Task Diffusion Transformer (DiT) Policy + +Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning. +Supports both diffusion and flow matching objectives for action generation. 
+""" + +from collections import deque + +import torch +from torch import Tensor + +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig +from lerobot.policies.multi_task_dit.modules.objectives import DiffusionObjective, FlowMatchingObjective +from lerobot.policies.multi_task_dit.modules.observation_encoder import ObservationEncoder +from lerobot.policies.multi_task_dit.modules.transformer import DiffusionTransformer +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import populate_queues +from lerobot.utils.constants import ACTION, OBS_IMAGES + + +class MultiTaskDiTPolicy(PreTrainedPolicy): + config_class = MultiTaskDiTConfig + name = "multi_task_dit" + + def __init__(self, config: MultiTaskDiTConfig): + super().__init__(config) + config.validate_features() + self.config = config + + self._queues = None + + self.observation_encoder = ObservationEncoder(config) + conditioning_dim = self.observation_encoder.conditioning_dim + self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim) + + action_dim = config.action_feature.shape[0] + horizon = config.horizon + + self.model_objective = config.model_objective + if config.is_diffusion: + self.objective = DiffusionObjective( + config.get_objective_config(), + action_dim=action_dim, + horizon=horizon, + do_mask_loss_for_padding=config.do_mask_loss_for_padding, + ) + elif config.is_flow_matching: + self.objective = FlowMatchingObjective( + config.get_objective_config(), + action_dim=action_dim, + horizon=horizon, + do_mask_loss_for_padding=config.do_mask_loss_for_padding, + ) + else: + raise ValueError(f"Unsupported model_objective: {self.model_objective}") + + self.reset() + + def get_optim_params(self) -> list: + """Returns parameter groups with different learning rates for vision vs non-vision parameters.""" + non_vision_params = [] + vision_encoder_params = [] + + for name, param in self.named_parameters(): + if not param.requires_grad: + continue + + if "observation_encoder.vision_encoder" in name: + vision_encoder_params.append(param) + else: + non_vision_params.append(param) + + return [ + {"params": non_vision_params}, + { + "params": vision_encoder_params, + "lr": self.config.optimizer_lr * self.config.observation_encoder.vision.lr_multiplier, + }, + ] + + def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor: + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert n_obs_steps == self.config.n_obs_steps + + conditioning_vec = self.observation_encoder.encode(batch) + actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec) + + start_idx = n_obs_steps - 1 + end_idx = start_idx + self.config.n_action_steps + return actions[:, start_idx:end_idx] + + def reset(self): + """Clear observation and action queues. 
Should be called on `env.reset()`.""" + self._queues = { + "observation.state": deque(maxlen=self.config.n_obs_steps), + "action": deque(maxlen=self.config.n_action_steps), + } + + if self.config.image_features: + self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) + + if self.config.env_state_feature: + self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) + + if self.config.observation_encoder.text: + self._queues["task"] = deque(maxlen=self.config.n_obs_steps) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: + """Run the batch through the model and compute the loss for training or validation.""" + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + + n_obs_steps = batch["observation.state"].shape[1] + horizon = batch["action"].shape[1] + assert horizon == self.config.horizon + assert n_obs_steps == self.config.n_obs_steps + + conditioning_vec = self.observation_encoder.encode(batch) + loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec) + + return loss, None + + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + self.eval() + + original_batch_keys = set(batch.keys()) + new_batch = {} + for k in self._queues: + if k in original_batch_keys: + if self._queues[k] and isinstance(self._queues[k][-1][0], str): + # for task description which is a list of strings + new_batch[k] = self._queues[k][-1] + else: + queue_values = list(self._queues[k]) + new_batch[k] = torch.stack(queue_values, dim=1) + batch = new_batch + + actions = self._generate_actions(batch) + return actions + + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method manages caching of observations and actions by generating an action chunk + and returning actions from the cache until it's depleted. + """ + if ACTION in batch: + batch.pop(ACTION) + + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + + self._queues = populate_queues(self._queues, batch) + + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)) + + return self._queues[ACTION].popleft() diff --git a/src/lerobot/policies/multi_task_dit/modules/__init__.py b/src/lerobot/policies/multi_task_dit/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/lerobot/policies/multi_task_dit/modules/objectives.py b/src/lerobot/policies/multi_task_dit/modules/objectives.py new file mode 100644 index 000000000..32dcd592e --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/modules/objectives.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains a base objective class, and implementation of the objective +classes for use in the Multi-Task Diffusion Transformer Policy. + +Architecture: +- BaseObjective: Abstract interface definition +- DiffusionObjective: Implements standard DDPM/DDIM diffusion objective +- FlowMatchingObjective: Implements flow matching objective +""" + +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from torch import Tensor + + +class BaseObjective(ABC): + """ + Base class for objectives used in Multi-Task DiT policy. + Defines the interface for training loss computation and conditional sampling. + """ + + def __init__(self, config, action_dim: int, horizon: int): + self.config = config + self.action_dim = action_dim + self.horizon = horizon + + @abstractmethod + def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: + """Compute training loss for the objective. + + Args: + model: The prediction network + batch: Training batch with observations and actions + conditioning_vec: Encoded observation features for conditioning + + Returns: + Scalar loss tensor + """ + pass + + @abstractmethod + def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: + """Generate action samples conditioned on embedded observation features. + + Args: + model: The prediction network + batch_size: The number of samples to generate + conditioning_vec: Encoded observation features for conditioning + + Returns: + Generated action sequences (batch_size, horizon, action_dim) + """ + pass + + +class DiffusionObjective(BaseObjective): + """Standard diffusion (DDPM/DDIM) objective implementation. + + Contains the noise scheduler, training loss, and conditional sampling. 
+ """ + + def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): + super().__init__(config, action_dim, horizon) + self.do_mask_loss_for_padding = do_mask_loss_for_padding + + # Build noise scheduler + scheduler_kwargs = { + "num_train_timesteps": config.num_train_timesteps, + "beta_start": config.beta_start, + "beta_end": config.beta_end, + "beta_schedule": config.beta_schedule, + "clip_sample": getattr(config, "clip_sample", True), + "clip_sample_range": getattr(config, "clip_sample_range", 1.0), + "prediction_type": config.prediction_type, + } + + if config.noise_scheduler_type == "DDPM": + self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs) + elif config.noise_scheduler_type == "DDIM": + self.noise_scheduler = DDIMScheduler(**scheduler_kwargs) + else: + raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}") + + # Inference steps default to training steps if not provided + self.num_inference_steps = ( + config.num_inference_steps + if getattr(config, "num_inference_steps", None) is not None + else self.noise_scheduler.config.num_train_timesteps + ) + + def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: + clean_actions = batch["action"] + noise = torch.randn_like(clean_actions) + timesteps = torch.randint( + low=0, + high=self.noise_scheduler.config.num_train_timesteps, + size=(clean_actions.shape[0],), + device=clean_actions.device, + ).long() + noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps) + + # Target depends on prediction type + prediction_type = self.noise_scheduler.config.prediction_type + if prediction_type == "epsilon": + target = noise + elif prediction_type == "sample": + target = clean_actions + else: + raise ValueError(f"Unsupported prediction type: {prediction_type}") + + predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec) + loss = F.mse_loss(predicted, target, reduction="none") + + if self.do_mask_loss_for_padding and "action_is_pad" in batch: + valid_actions = ~batch["action_is_pad"] # (B, T) + loss = loss * valid_actions.unsqueeze(-1) + + return loss.mean() + + def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + sample = torch.randn( + size=(batch_size, self.horizon, self.action_dim), + dtype=dtype, + device=device, + ) + + self.noise_scheduler.set_timesteps(self.num_inference_steps) + for t in self.noise_scheduler.timesteps: + model_output = model( + sample, + torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device), + conditioning_vec=conditioning_vec, + ) + sample = self.noise_scheduler.step(model_output, t, sample).prev_sample + + return sample + + +class FlowMatchingObjective(BaseObjective): + """ + Flow matching objective: trains a model to predict velocity fields v_θ(x, t) that transports + noise to data. This basically interpolates as path between noise and a trained distribution + with a set target velocity. + """ + + def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): + super().__init__(config, action_dim, horizon) + self.do_mask_loss_for_padding = do_mask_loss_for_padding + + def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor: + """Sample timesteps according to configured strategy. 
+ + Uniform: Sample t uniformly from [0,1] + Beta: Sample t from Beta(α,β) scaled to [0,s], emphasizing high noise (low t) + """ + if self.config.timestep_sampling.strategy_name == "uniform": + return torch.rand(batch_size, device=device) + elif self.config.timestep_sampling.strategy_name == "beta": + # Sample u ~ Beta(α, β) then transform: t = s(1-u) + # This emphasizes t near 0 (high noise) when α > β + beta_dist = torch.distributions.Beta( + self.config.timestep_sampling.alpha, self.config.timestep_sampling.beta + ) + u = beta_dist.sample((batch_size,)).to(device) + return self.config.timestep_sampling.s * (1.0 - u) + else: + raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling.strategy_name}") + + def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: + """Compute flow matching training loss. + + Trains the model to predict the velocity field along linear interpolation paths. + """ + data = batch["action"] # Clean action sequences (B, T, D) + batch_size = data.shape[0] + device = data.device + + noise = torch.randn_like(data) + t = self._sample_timesteps(batch_size, device) + t_expanded = t.view(-1, 1, 1) # (B, 1, 1) for broadcasting + x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise + + # The velocity we want the model to learn: v = data - (1-σ)·noise + target_velocity = data - (1 - self.config.sigma_min) * noise + predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec) + loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none") + + # Optionally mask padded actions + if self.do_mask_loss_for_padding and "action_is_pad" in batch: + valid_mask = ~batch["action_is_pad"] # (B, T) + loss = loss * valid_mask.unsqueeze(-1) # (B, T, D) + + return loss.mean() + + def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: + """Generate actions by integrating the learned velocity field via ODE. 
+ + Solves: dx/dt = v_θ(x,t) from t=0 (noise) to t=1 (data) + """ + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + # Start from random noise at t=0 + x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device) + + # Time grid from 0 to 1 + num_steps = self.config.num_integration_steps + time_grid = torch.linspace(0, 1, num_steps + 1, device=device) + + # Integrate ODE using chosen method + if self.config.integration_method == "euler": + x = self._euler_integrate(model, x, time_grid, conditioning_vec) + elif self.config.integration_method == "rk4": + x = self._rk4_integrate(model, x, time_grid, conditioning_vec) + else: + raise ValueError(f"Unknown integration method: {self.config.integration_method}") + + return x + + def _euler_integrate( + self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor + ) -> Tensor: + """ + Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n) + """ + x = x_init + + for i in range(len(time_grid) - 1): + t_scalar = time_grid[i].item() + dt = (time_grid[i + 1] - time_grid[i]).item() + + # Create time tensor for batch + t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device) + + # Get velocity at current point + with torch.no_grad(): + velocity = model(x, t_batch, conditioning_vec=conditioning_vec) + + # Euler step + x = x + dt * velocity + + return x + + def _rk4_integrate( + self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor + ) -> Tensor: + """4th-order Runge-Kutta integration. + + Uses 4 velocity evaluations per step: + k1 = v(x, t) + k2 = v(x + dt·k1/2, t + dt/2) + k3 = v(x + dt·k2/2, t + dt/2) + k4 = v(x + dt·k3, t + dt) + x_next = x + dt/6·(k1 + 2k2 + 2k3 + k4) + + 4x slower than Euler but more accurate. + + Note: In practice, this seems to not matter much + """ + x = x_init + + def dynamics(x_val: Tensor, t_scalar: float) -> Tensor: + """dynamics helper to get velocity at (x, t)""" + t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device) + with torch.no_grad(): + return model(x_val, t_batch, conditioning_vec=conditioning_vec) + + for i in range(len(time_grid) - 1): + t = time_grid[i].item() + dt = (time_grid[i + 1] - time_grid[i]).item() + + # RK4 stages + k1 = dynamics(x, t) + k2 = dynamics(x + dt * k1 / 2, t + dt / 2) + k3 = dynamics(x + dt * k2 / 2, t + dt / 2) + k4 = dynamics(x + dt * k3, t + dt) + + # Weighted combination + x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4) + + return x diff --git a/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py b/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py new file mode 100644 index 000000000..de39d8ee8 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Observation encoding for Multi-Task DiT policy. + +Handles vision encoding, text encoding, robot state, and environment state. +""" + +from abc import ABC, abstractmethod + +import einops +import timm +import torch +import torch.nn as nn +import torchvision +from torch import Tensor +from transformers import CLIPTextModel, CLIPTokenizer + +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE + + +class BaseVisionEncoder(ABC): + """Abstract base class for vision encoders.""" + + @abstractmethod + def forward(self, x: Tensor) -> Tensor: + """Encode RGB image to feature maps.""" + pass + + @abstractmethod + def get_output_shape(self) -> tuple: + """Get the output shape (C', H', W').""" + pass + + +class DinoV3Encoder(nn.Module, BaseVisionEncoder): + """DinoV3 vision encoder using the CLS token for global image representation.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.model_name = config.backbone + + # Create the timm model + self.model = timm.create_model( + self.model_name, + pretrained=True, + num_classes=0, + ) + + self.num_non_spatial_tokens = 5 # 1 CLS + 4 register + self.embed_dim = self.model.embed_dim + + def forward(self, x: Tensor) -> Tensor: + """Encode RGB image to feature maps.""" + # Extract all features + features = self.model.forward_features(x) # (B, total_tokens, embed_dim) + + # Use only the CLS token (first token) + cls_token = features[:, 0] # (B, embed_dim) + b, embed_dim = cls_token.shape + + # Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility + cls_features = cls_token.reshape(b, embed_dim, 1, 1) + return cls_features + + def get_output_shape(self) -> tuple: + return (self.embed_dim, 1, 1) + + +class CLIPEncoder(nn.Module, BaseVisionEncoder): + """CLIP vision encoder using the CLS token for global image representation.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.model_name = config.backbone + + # Create the timm model + self.model = timm.create_model( + self.model_name, + pretrained=True, + num_classes=0, # Remove classification head, we want features + ) + + # CLIP models have 1 CLS token (no register tokens like DinoV3) + self.num_non_spatial_tokens = 1 + + # Get embed_dim from model config + self.embed_dim = self.model.embed_dim + + def forward(self, x: Tensor) -> Tensor: + """Encode RGB image to CLS token. + + Preprocessing (resize, crop) is handled by ObservationEncoder + """ + # Extract all features + features = self.model.forward_features(x) # (B, total_tokens, embed_dim) + + # Use only the CLS token (first token) + cls_token = features[:, 0] # (B, embed_dim) + b, embed_dim = cls_token.shape + + # Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility + cls_features = cls_token.reshape(b, embed_dim, 1, 1) + return cls_features + + def get_output_shape(self) -> tuple: + return (self.embed_dim, 1, 1) + + +def create_vision_encoder(config) -> BaseVisionEncoder: + backbone_name = config.backbone.lower() + + # Check if it's a CLIP model + if "clip" in backbone_name: + return CLIPEncoder(config) + + # Check if it's a DinoV3 model + elif "dinov3" in backbone_name: + return DinoV3Encoder(config) + + else: + raise ValueError( + f"Unsupported vision backbone: {config.backbone}. 
" + f"Currently supported: DinoV3 models and CLIP models" + ) + + +# Registry for easy extension +VISION_ENCODER_REGISTRY: dict[str, type] = { + "dinov3": DinoV3Encoder, + "clip": CLIPEncoder, +} + + +def register_vision_encoder(name: str, encoder_class: type): + """Register a new vision encoder type. + + Args: + name: Identifier for the encoder type + encoder_class: Class implementing BaseVisionEncoder interface + """ + VISION_ENCODER_REGISTRY[name] = encoder_class + + +def get_registered_encoders() -> dict[str, type]: + """Get all registered vision encoder types. + + Returns: + Dictionary mapping encoder names to classes + """ + return VISION_ENCODER_REGISTRY.copy() + + +class CLIPTextEncoder(nn.Module): + """CLIP text encoder with frozen weights and learnable projection.""" + + def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512): + super().__init__() + + self.model_name = model_name + self.projection_dim = projection_dim + + # Load CLIP text encoder and tokenizer + self.tokenizer = CLIPTokenizer.from_pretrained(model_name) + self.text_encoder = CLIPTextModel.from_pretrained(model_name) + + # Freeze all CLIP text encoder parameters + for param in self.text_encoder.parameters(): + param.requires_grad = False + + self.text_embed_dim = self.text_encoder.config.hidden_size + + # Learnable projection layer (always present, only trainable component) + self.projection = nn.Linear(self.text_embed_dim, projection_dim) + + def forward(self, text: str | list[str]) -> Tensor: + """Encode text to feature vectors. + + Args: + text: Single string or list of strings + + Returns: + Text features of shape (B, projection_dim) + """ + # handle single string input + if isinstance(text, str): + text = [text] + + text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt") + + text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()} + + # encode text through CLIP (frozen) + with torch.no_grad(): + outputs = self.text_encoder(**text_inputs) + # Extract pooled output (EOS token embedding) + clip_features = outputs.pooler_output # (B, text_embed_dim) + + # project to desired dimension (trainable) + projected_features = self.projection(clip_features) # (B, projection_dim) + + return projected_features + + +class ObservationEncoder(nn.Module): + """Handles all observation processing for the conditioning vector.""" + + def __init__(self, config): + super().__init__() + self.config = config + vision_config = config.observation_encoder.vision + + self._setup_preprocessing(vision_config) + + if config.image_features: + self.num_cameras = len(config.image_features) + self.camera_names = list(config.image_features.keys()) # Preserve ordering + + if vision_config.use_separate_encoder_per_camera: + self.vision_encoders = nn.ModuleList( + [create_vision_encoder(vision_config) for _ in self.camera_names] + ) + self.vision_encoder = None + else: + self.vision_encoder = create_vision_encoder(vision_config) + self.vision_encoders = None + else: + self.vision_encoder = None + self.vision_encoders = None + self.camera_names = [] + self.num_cameras = 0 + + if hasattr(config, "robot_state_feature") and config.robot_state_feature: + self.robot_state_dim = config.robot_state_feature.shape[0] + else: + self.robot_state_dim = 0 + + if hasattr(config, "env_state_feature") and config.env_state_feature: + self.env_state_dim = config.env_state_feature.shape[0] + else: + self.env_state_dim = 0 + + text_config = 
config.observation_encoder.text
+        self.text_dim = config.transformer.hidden_dim
+        self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
+
+        self._setup_vector_output()
+
+    def _apply_preprocessing(self, images: Tensor) -> Tensor:
+        """Apply preprocessing transforms to images."""
+        if self.do_resize:
+            images = self.resize(images)
+        if self.do_crop:
+            images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
+
+        return images
+
+    def _setup_preprocessing(self, vision_config):
+        """Setup image preprocessing transforms."""
+        if vision_config.resize_shape is not None:
+            self.do_resize = True
+            self.resize = torchvision.transforms.Resize(
+                size=vision_config.resize_shape,
+                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
+                antialias=True,
+            )
+        else:
+            self.do_resize = False
+        if vision_config.crop_shape is not None:
+            self.do_crop = True
+            self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
+            if vision_config.crop_is_random:
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
+            else:
+                self.maybe_random_crop = self.center_crop
+        else:
+            self.do_crop = False
+
+    def _setup_vector_output(self):
+        """Setup for vector output."""
+        total_dim = 0
+
+        # Vision features - get CLS token feature dimension
+        if self.vision_encoder is not None or self.vision_encoders is not None:
+            # self.vision_encoders is an nn.ModuleList, so index it directly (it has no .values())
+            encoder_to_check = self.vision_encoder or self.vision_encoders[0]
+
+            # Get output shape from encoder (deterministic for CLS tokens)
+            feature_map_shape = encoder_to_check.get_output_shape()
+            c, h, w = feature_map_shape
+            spatial_feature_dim = c * h * w  # For CLS token: embed_dim * 1 * 1 = embed_dim
+
+            total_dim += spatial_feature_dim * self.num_cameras
+
+        # State features
+        total_dim += self.robot_state_dim
+        total_dim += self.env_state_dim
+
+        # Text features
+        total_dim += self.text_dim
+
+        # Account for temporal stacking
+        self.conditioning_dim = total_dim * self.config.n_obs_steps
+
+    def encode(self, batch: dict) -> Tensor:
+        """Encode observations to vector format."""
+        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
+        conditioning_feats = []
+
+        conditioning_feats.append(batch[OBS_STATE])
+
+        if self.vision_encoder is not None or self.vision_encoders is not None:
+            images = batch[OBS_IMAGES]  # (B, n_obs_steps, num_cameras, C, H, W)
+
+            # Handle case when n_obs=1 and time dimension might be squeezed
+            if len(images.shape) == 5:
+                # Shape is (B, N, C, H, W) - add time dimension
+                images = images.unsqueeze(1)  # (B, 1, N, C, H, W)
+
+            vision_config = self.config.observation_encoder.vision
+            if vision_config.use_separate_encoder_per_camera:
+                # Process each camera with its own encoder
+                camera_features = []
+
+                for cam_idx in range(self.num_cameras):
+                    # Extract images for this camera: (B, n_obs_steps, C, H, W)
+                    cam_images = images[:, :, cam_idx]
+
+                    # Rearrange to: (B*n_obs_steps, C, H, W)
+                    cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
+
+                    # Apply preprocessing
+                    cam_images_flat = self._apply_preprocessing(cam_images_flat)
+
+                    # Process with camera-specific encoder (direct index access)
+                    cam_features = self.vision_encoders[cam_idx](cam_images_flat)
+
+                    # Apply spatial vectorization (flatten CLS token features)
+                    cam_visual_features = cam_features.flatten(start_dim=1)
+
+                    # Reshape back: (B*n_obs_steps, feature_dim) → (B, n_obs_steps, feature_dim)
+                    cam_features_reshaped = einops.rearrange(
+                        cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
+                    )
+                    camera_features.append(cam_features_reshaped)
+
+                # Concatenate features from all cameras: (B, n_obs_steps, total_feature_dim)
+                img_features = torch.cat(camera_features, dim=-1)
+                conditioning_feats.append(img_features)
+
+            else:
+                # Shared encoder for all cameras
+                # Rearrange to: (B*n_obs_steps*num_cameras, C, H, W)
+                images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
+
+                images_flat = self._apply_preprocessing(images_flat)
+
+                visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
+
+                # Reshape back and concatenate camera features
+                # (B*n_obs_steps*num_cameras, feature_dim) → (B, n_obs_steps, num_cameras*feature_dim)
+                img_features = einops.rearrange(
+                    visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
+                )
+
+                conditioning_feats.append(img_features)
+
+        if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
+            conditioning_feats.append(batch[OBS_ENV_STATE])
+
+        if self.text_encoder is not None and "task" in batch:
+            text_features = self.text_encoder(batch["task"])  # (B, text_dim)
+            # Expand across temporal dimension to match other features
+            text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)  # (B, T, text_dim)
+            conditioning_feats.append(text_features)
+
+        combined_features = torch.cat(conditioning_feats, dim=-1)  # (B, n_obs_steps, total_feature_dim)
+
+        return combined_features.flatten(start_dim=1)  # (B, n_obs_steps * total_feature_dim)
diff --git a/src/lerobot/policies/multi_task_dit/modules/transformer.py b/src/lerobot/policies/multi_task_dit/modules/transformer.py
new file mode 100644
index 000000000..3f8415574
--- /dev/null
+++ b/src/lerobot/policies/multi_task_dit/modules/transformer.py
@@ -0,0 +1,247 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transformer backbone for noise prediction in Multi-Task DiT policy.
+
+Adapted from DiT (Diffusion Transformer: https://github.com/facebookresearch/DiT) for 1D trajectory data.
+"""
+
+import math
+
+import torch
+from torch import Tensor, nn
+
+
+def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
+    """Modulate input with shift and scale for AdaLN-Zero.
+
+    Args:
+        x: Input tensor
+        shift: Shift parameter
+        scale: Scale parameter
+
+    Returns:
+        Modulated tensor: x * (1 + scale) + shift
+    """
+    return x * (1 + scale) + shift
+
+
+class SinusoidalPosEmb(nn.Module):
+    """Sinusoidal positional embeddings for timesteps.
+
+    Identical to the reference implementation - generates smooth embeddings
+    for diffusion timestep values.
+ """ + + def __init__(self, dim: int): + """ + Args: + dim: Embedding dimension + """ + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: (B,) tensor of timestep values + + Returns: + (B, dim) positional embeddings + """ + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class TransformerBlock(nn.Module): + """DiT-style transformer block with AdaLN-Zero. + + Official DiT implementation with 6-parameter adaptive layer normalization: + - shift_msa, scale_msa, gate_msa: for attention block + - shift_mlp, scale_mlp, gate_mlp: for MLP block + + Reference: https://github.com/facebookresearch/DiT + """ + + def __init__( + self, hidden_size: int = 128, num_heads: int = 4, num_features: int = 128, dropout: float = 0.0 + ): + """ + Args: + hidden_size: Hidden dimension of transformer + num_heads: Number of attention heads + num_features: Size of conditioning features + dropout: Dropout rate + """ + super().__init__() + + self.multihead_attn = nn.MultiheadAttention( + hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout + ) + + # Layer normalizations (no learnable affine parameters, all adaptation via conditioning) + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + # Feed-forward network (MLP) + self.mlp = nn.Sequential( + nn.Linear(hidden_size, hidden_size * 4), + nn.GELU(approximate="tanh"), + nn.Linear(hidden_size * 4, hidden_size), + ) + + # AdaLN-Zero modulation: produces 6 parameters (shift, scale, gate for attn and mlp) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True)) + + def forward(self, x: Tensor, features: Tensor) -> Tensor: + """ + Args: + x: (B, T, hidden_size) input sequence + features: (B, num_features) conditioning features + + Returns: + (B, T, hidden_size) processed sequence + """ + # Generate 6 modulation parameters from conditioning + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation( + features + ).chunk(6, dim=1) + + # Attention block: norm → modulate → attn → gate × output → residual + # modulate requires unsqueeze(1) to add sequence dimension for broadcasting + attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1)) + attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input) + x = x + gate_msa.unsqueeze(1) * attn_out + + # MLP block: norm → modulate → mlp → gate × output → residual + mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1)) + mlp_out = self.mlp(mlp_input) + x = x + gate_mlp.unsqueeze(1) * mlp_out + + return x + + +class DiffusionTransformer(nn.Module): + """ + Transformer-based diffusion noise prediction model. + """ + + def __init__(self, config, conditioning_dim: int): + """Initialize transformer for noise prediction. 
+ + Args: + config: Multi-Task DiTConfig with transformer parameters + conditioning_dim: Dimension of concatenated observation features + """ + super().__init__() + + self.config = config + self.transformer_config = config.transformer + self.conditioning_dim = conditioning_dim + + self.action_dim = config.action_feature.shape[0] + self.horizon = config.horizon + self.hidden_size = self.transformer_config.hidden_dim + self.num_layers = self.transformer_config.num_layers + self.num_heads = self.transformer_config.num_heads + self.dropout = self.transformer_config.dropout + + self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(self.timestep_embed_dim), + nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim), + nn.GELU(), + nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim), + nn.GELU(), + ) + + self.cond_dim = self.timestep_embed_dim + conditioning_dim + + # Project action dimensions to hidden size + self.input_proj = nn.Linear(self.action_dim, self.hidden_size) + + if self.transformer_config.use_positional_encoding: + # Learnable positional embeddings for sequence positions + self.pos_embedding = nn.Parameter( + torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02) + ) + else: + self.pos_embedding = None + + self.transformer_blocks = nn.ModuleList( + [ + TransformerBlock( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + num_features=self.cond_dim, + dropout=self.dropout, + ) + for _ in range(self.num_layers) + ] + ) + + # Project back to action dimensions + self.output_proj = nn.Linear(self.hidden_size, self.action_dim) + + # Zero-initialize adaLN_modulation layers for AdaLN-Zero + self._initialize_weights() + + def _initialize_weights(self): + """ + Zero-initializing the final linear layer of adaLN_modulation in each block improves training stability + """ + for block in self.transformer_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor: + """Predict noise to remove from noisy actions. 
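+
+ Performs a single denoising step: the timestep embedding from time_mlp is concatenated with
+ conditioning_vec and passed to each transformer block as the AdaLN-Zero conditioning signal.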
+ + Args: + x: (B, T, action_dim) noisy action sequences + timestep: (B,) diffusion timesteps + conditioning_vec: (B, conditioning_dim) observation features (required) + + Returns: + (B, T, action_dim) predicted noise + """ + _, seq_len, _ = x.shape + + timestep_features = self.time_mlp(timestep) # (B, timestep_embed_dim) + + # conditioning_vec is now required + cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1) # (B, cond_dim) + + # Project action sequence to hidden dimension + hidden_seq = self.input_proj(x) # (B, T, hidden_size) + + if self.pos_embedding is not None: + # Add learned positional embeddings + hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :] # (B, T, hidden_size) + + # Pass through transformer layers with conditioning + for block in self.transformer_blocks: + hidden_seq = block(hidden_seq, cond_features) # (B, T, hidden_size) + + # Project back to action space + output = self.output_proj(hidden_seq) # (B, T, action_dim) + + return output diff --git a/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py new file mode 100644 index 000000000..bb01ae41f --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
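+
+"""Pre- and post-processing pipelines (renaming, batching, device placement, and (un)normalization) for the Multi-Task DiT policy."""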
+ +from typing import Any + +import torch + +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + +def make_multi_task_dit_pre_post_processors( + config: MultiTaskDiTConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Creates pre- and post-processing pipelines for the Multi-Task DiT policy.""" + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + device=config.device, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/tests/policies/test_multi_task_dit_policy.py b/tests/policies/test_multi_task_dit_policy.py new file mode 100644 index 000000000..575a7a9ca --- /dev/null +++ b/tests/policies/test_multi_task_dit_policy.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script for Multi-Task DiT policy. 
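+
+ Covers the training forward/backward pass, action selection, the diffusion and flow matching objectives,
+ save/load round-trips, and optimizer parameter groups.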
+ +To run tests with GPU on Modal (temporary script): + modal run run_tests_modal.py + +To run tests locally: + python -m pytest tests/policies/test_multi_task_dit_policy.py -v +""" + +import pytest +import torch +from torch import Tensor + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import ( + DiffusionConfig, + FlowMatchingConfig, + MultiTaskDiTConfig, +) +from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.utils.random_utils import seeded_context, set_seed + + +@pytest.fixture(autouse=True) +def set_random_seed(): + seed = 17 + set_seed(seed) + + +def create_train_batch( + batch_size: int = 2, + n_obs_steps: int = 2, + horizon: int = 16, + state_dim: int = 10, + action_dim: int = 10, + height: int = 224, + width: int = 224, +) -> dict[str, Tensor]: + """Create a training batch with visual input and text.""" + return { + "observation.state": torch.randn(batch_size, n_obs_steps, state_dim), + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width), + ACTION: torch.randn(batch_size, horizon, action_dim), + "task": ["pick up the cube"] * batch_size, + } + + +def create_observation_batch( + batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224 +) -> dict: + """Create observation batch for inference for a single timestep.""" + return { + "observation.state": torch.randn(batch_size, state_dim), + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width), + "task": ["pick up the red cube"] * batch_size, + } + + +def create_config( + state_dim: int = 10, + action_dim: int = 10, + n_obs_steps: int = 2, + horizon: int = 16, + n_action_steps: int = 8, + with_visual: bool = True, + height: int = 224, + width: int = 224, +) -> MultiTaskDiTConfig: + """Create a MultiTaskDiT config for testing. 
+ + Args: + state_dim: Dimension of state observations + action_dim: Dimension of actions + n_obs_steps: Number of observation steps + horizon: Action prediction horizon + n_action_steps: Number of action steps to execute + with_visual: Whether to include visual input (default: True) + height: Image height (only used if with_visual=True) + width: Image width (only used if with_visual=True) + """ + input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))} + + if with_visual: + input_features[f"{OBS_IMAGES}.laptop"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(3, height, width) + ) + + config = MultiTaskDiTConfig( + input_features=input_features, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + # Use smaller model for faster tests + config.transformer.hidden_dim = 128 + config.transformer.num_layers = 2 + config.transformer.num_heads = 4 + + config.validate_features() + return config + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)]) +def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int): + """Test forward pass (training mode).""" + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.train() + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Test forward pass + loss, _ = policy.forward(batch) + assert loss is not None + assert loss.item() is not None + assert loss.shape == () + + # Test backward pass + loss.backward() + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)]) +def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int): + """Test select_action (inference mode).""" + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.eval() + policy.reset() # Reset queues before inference + + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, action_dim) + + +def test_multi_task_dit_policy_diffusion_objective(): + """Test policy with diffusion objective.""" + batch_size = 2 + state_dim = 10 + action_dim = 10 + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + config.objective = DiffusionConfig( + noise_scheduler_type="DDPM", + num_train_timesteps=100, + num_inference_steps=10, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.train() + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Test forward pass + loss, _ = policy.forward(batch) + assert loss is not None + assert loss.item() is not None + + # Test inference + 
policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, action_dim) + + +def test_multi_task_dit_policy_flow_matching_objective(): + """Test policy with flow matching objective.""" + batch_size = 2 + state_dim = 10 + action_dim = 10 + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + config.objective = FlowMatchingConfig( + sigma_min=0.0, + num_integration_steps=10, # Use fewer steps for faster tests + integration_method="euler", + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.train() + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Test forward pass + loss, _ = policy.forward(batch) + assert loss is not None + assert loss.item() is not None + + # Test inference + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, action_dim) + + +def test_multi_task_dit_policy_save_and_load(tmp_path): + """Test that the policy can be saved and loaded correctly.""" + root = tmp_path / "test_multi_task_dit_save_and_load" + + state_dim = 10 + action_dim = 10 + batch_size = 2 + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.eval() + + # Get device before saving + device = next(policy.parameters()).device + + policy.save_pretrained(root) + loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config) + + # Explicitly move loaded_policy to the same device + loaded_policy.to(device) + loaded_policy.eval() + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Move batch to the same device as the policy + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].to(device) + + with torch.no_grad(): + with seeded_context(12): + # Collect policy values before saving + loss, _ = policy.forward(batch) + + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Move observation batch to device + for key in observation_batch: + if isinstance(observation_batch[key], torch.Tensor): + observation_batch[key] = observation_batch[key].to(device) + actions = policy.select_action(observation_batch) + + with seeded_context(12): + # Collect policy values after loading + loaded_loss, _ = loaded_policy.forward(batch) + + loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Move observation batch to device + for key in loaded_observation_batch: + if isinstance(loaded_observation_batch[key], torch.Tensor): + loaded_observation_batch[key] = loaded_observation_batch[key].to(device) + loaded_actions = loaded_policy.select_action(loaded_observation_batch) + + # Compare state dicts + assert policy.state_dict().keys() == loaded_policy.state_dict().keys() + 
for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) + + # Compare values before and after saving and loading + assert torch.allclose(loss, loaded_loss) + assert torch.allclose(actions, loaded_actions) + + +def test_multi_task_dit_policy_get_optim_params(): + """Test that the policy returns correct optimizer parameter groups.""" + config = create_config( + state_dim=10, + action_dim=10, + n_obs_steps=2, + horizon=16, + n_action_steps=8, + ) + + policy = MultiTaskDiTPolicy(config=config) + param_groups = policy.get_optim_params() + + # Should have 2 parameter groups: non-vision and vision encoder + assert len(param_groups) == 2 + + # First group is non-vision params (no lr specified, will use default) + assert "params" in param_groups[0] + assert len(param_groups[0]["params"]) > 0 + + # Second group is vision encoder params with different lr + assert "params" in param_groups[1] + assert "lr" in param_groups[1] + expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier + assert param_groups[1]["lr"] == expected_lr
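
A quick illustration of the AdaLN-Zero property that `_initialize_weights` relies on ("zero-initializing the final linear layer of adaLN_modulation ... improves training stability"): when the modulation output is all zeros, `modulate` reduces to the plain layer norm and the zero gates cancel each residual branch, so every `TransformerBlock` starts out as an identity map. The sketch below reproduces just that arithmetic; the toy shapes and the stand-in `adaLN` module are illustrative only and not part of this diff.

```python
import torch
from torch import nn

# Toy dimensions, for illustration only (not taken from the config defaults).
batch, seq_len, hidden_size, num_features = 2, 4, 8, 16
x = torch.randn(batch, seq_len, hidden_size)
features = torch.randn(batch, num_features)

# Stand-in for one block's adaLN_modulation, with the final linear layer zero-initialized
# the same way DiffusionTransformer._initialize_weights does it.
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
nn.init.constant_(adaLN[-1].weight, 0)
nn.init.constant_(adaLN[-1].bias, 0)

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(features).chunk(6, dim=1)

norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# modulate(): with zero shift/scale this is just norm(x), and the zero gate removes the whole
# residual branch, so the block output equals its input at initialization.
modulated = norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
out = x + gate_msa.unsqueeze(1) * modulated  # attention omitted: the zero gate nulls the branch anyway
assert torch.allclose(out, x)
```

The same reasoning applies to the MLP branch, whose `gate_mlp` is likewise zero at initialization.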