Feature/add multitask diffusion transformer policy implementation (#2545)
* Add multitask diffusion transformer policy
* expand the observation encoder to support different size encoders for vision and text
* add RoPE attention module as this is shown to help training dynamics and generation quality for DiTs
* update readme and citations for multitask dit policy
* remove dino vision encoder and simplify text and vision encoders by removing inheritance structure
* adjust factory comment
* update docstring for multitask dit policy processor file
* simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives
* add references to the modeling file comments
* merge all module files into the main modeling file
* add torch.no_grad decorators
* split up select action return statement
* remove redundant asserts
* add tutorial for training with multi_task_dit
* fix bugs when testing on hardware
* remove environment state conditioning
* fix typo in test instruction comment
* add processor tests to multitask dit tests
* move policy to top of file
* use constants for indexing into batches and remove env state references
* remove the base classes since we don't need to be able to extend them
* fix formatting nit in generate actions function
* reformat and clean up tutorial for multitask dit policy
* add more descriptions and depth to multitask dit tutorial
* note origins of each training objective
* rename config param for multiple vision encoders
* refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit
* add multitask dit to ToC for docs
* add conditional transformers import to match all other policies that use the transformers lib
* add test handling for multitask dit when transformers isn't available
* skip tests without transformers
* remove cropping of images smaller than the crop size
* add kwargs arg to multitask dit constructor
* add wall_x dep conflict management for multitask dit policy
* use hyphens for cleanliness in pyproject.toml
* add conflict management to pyproject.toml for the pi conflict for mtdp as well
* update tests script to drop an unnecessary `uv sync` call that resolved dependencies that did not need to run; this drastically reduces CI run time
* revert fast tests edits
* update docs and readme files, fixing some typos and adding multitask dit to the readme
* chore(dependencies): upgrade transformers + huggingface-hub + peft + scipy
* chore(dependencies): bump pi0 family to transformers v5
* chore(dependencies): bump wall x to transformers v5
* chore(dependencies): bump gr00t to transformers v5
* chore(style): fix pre-commit
* fix(policy): xvla forced_bos_token missing
* test(rl): skip CI tests for resnet10
* Fix: full pi models support for transformers v5 (#2967)
  * fix(pi): remove loss truncation
  * fix(pi): remove state padding before tokenization
  * fix(pi): fix image padding value
  * fix from_pretrained
  * add transformers v5 changes
  * remove reference
  * more fixes
  * make it work
  * add support for rest of pi family
  * add pifast work
  * more changes
  * more changes
  * more cleanup
  * fix torch params
  * dtype fix
  * torch compile
  * embed mismatch fix
  * revert groot
  * more nit fixes
  * remove unused classes
  * more fixes
  * revert
  * nit
  * torch dtype warning fix
  * put back dynamic renaming
  * add tie embedding

  ---------

  Co-authored-by: Yufei Sun <skieyfly@gmail.com>

* chore: fix XVLA in transformers v5 (#3006)
* test(policies): enable wall x CI testing
* style(test): pre-commit check
* style(test): pre-commit

---------

Signed-off-by: Bryson Jones <63133702+brysonjones@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Yufei Sun <skieyfly@gmail.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
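The two training objectives wired up in the diff below can be summarized in the notation of the modeling code (a sketch of the math the code implements, not an addition to the diff; here c is the conditioning vector, x_1 the clean action chunk, x_0 Gaussian noise):

```latex
% Diffusion objective (prediction_type = "epsilon"):
\mathcal{L}_{\text{diff}} = \mathbb{E}_{t,\,\epsilon}\,\big\|\epsilon_\theta(x_t, t, c) - \epsilon\big\|^2
% Flow matching objective, with sigma_min written as \sigma_{\min}:
x_t = t\,x_1 + \big(1 - (1 - \sigma_{\min})\,t\big)\,x_0,\qquad
v_{\text{target}} = x_1 - (1 - \sigma_{\min})\,x_0
\mathcal{L}_{\text{fm}} = \mathbb{E}_{t,\,x_0}\,\big\|v_\theta(x_t, t, c) - v_{\text{target}}\big\|^2
```

At inference, the diffusion objective denoises with DDPM/DDIM scheduler steps, while flow matching integrates the learned velocity field from t = 0 to t = 1 with an Euler or RK4 solver.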
@@ -15,6 +15,7 @@
 from .act.configuration_act import ACTConfig as ACTConfig
 from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
 from .groot.configuration_groot import GrootConfig as GrootConfig
+from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
 from .pi0.configuration_pi0 import PI0Config as PI0Config
 from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
 from .pi05.configuration_pi05 import PI05Config as PI05Config
@@ -28,6 +29,7 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
 __all__ = [
     "ACTConfig",
     "DiffusionConfig",
+    "MultiTaskDiTConfig",
     "PI0Config",
     "PI05Config",
     "PI0FastConfig",
@@ -31,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features
 from lerobot.policies.act.configuration_act import ACTConfig
 from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.policies.groot.configuration_groot import GrootConfig
+from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
 from lerobot.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.policies.pi05.configuration_pi05 import PI05Config
 from lerobot.policies.pretrained import PreTrainedPolicy
@@ -67,8 +68,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
 
     Args:
         name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
-            "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
-
+            "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
     Returns:
         The policy class corresponding to the given name.
 
@@ -87,6 +87,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.act.modeling_act import ACTPolicy
 
         return ACTPolicy
+    elif name == "multi_task_dit":
+        from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
+
+        return MultiTaskDiTPolicy
     elif name == "vqbet":
         from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
 
@@ -147,8 +151,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
 
     Args:
         policy_type: The type of the policy. Supported types include "tdmpc",
-            "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
-            "reward_classifier", "wall_x".
+            "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
+            "smolvla", "reward_classifier", "wall_x".
         **kwargs: Keyword arguments to be passed to the configuration class constructor.
 
     Returns:
@@ -163,6 +167,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return DiffusionConfig(**kwargs)
     elif policy_type == "act":
         return ACTConfig(**kwargs)
+    elif policy_type == "multi_task_dit":
+        return MultiTaskDiTConfig(**kwargs)
     elif policy_type == "vqbet":
         return VQBeTConfig(**kwargs)
     elif policy_type == "pi0":
@@ -309,6 +315,16 @@ def make_pre_post_processors(
             dataset_stats=kwargs.get("dataset_stats"),
         )
 
+    elif isinstance(policy_cfg, MultiTaskDiTConfig):
+        from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
+            make_multi_task_dit_pre_post_processors,
+        )
+
+        processors = make_multi_task_dit_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+        )
+
     elif isinstance(policy_cfg, VQBeTConfig):
         from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
 
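The factory hooks above register the policy under the name `multi_task_dit`. A minimal sketch of exercising them (an illustration only; it assumes a dataset or env has already populated the config's input/output features):

```python
from lerobot.policies.factory import get_policy_class, make_policy_config

# Build a config for the new policy type; extra kwargs go to MultiTaskDiTConfig.
cfg = make_policy_config("multi_task_dit", objective="flow_matching")

# Resolve the policy class lazily, exactly as the diff above does.
policy_cls = get_policy_class("multi_task_dit")  # -> MultiTaskDiTPolicy
```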
@@ -0,0 +1,37 @@
# Multitask DiT Policy

## Citation

If you use this work, please cite the following works:

```bibtex
@misc{jones2025multitaskditpolicy,
    author = {Bryson Jones},
    title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
    year = {2025},
    url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
    note = {Blog post}
}
```

```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
    author = {TRI LBM Team},
    title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
    year = {2025},
    eprint = {arXiv:2507.05331},
    archivePrefix = {arXiv},
    primaryClass = {cs.RO},
    url = {https://arxiv.org/abs/2507.05331}
}
```

```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
    author = {Boston Dynamics and TRI Research Team},
    title = {Large Behavior Models and Atlas Find New Footing},
    year = {2025},
    url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
    note = {Blog post}
}
```
@@ -0,0 +1,21 @@
#!/usr/bin/env python

# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_multi_task_dit import MultiTaskDiTConfig
from .modeling_multi_task_dit import MultiTaskDiTPolicy
from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors

__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"]
@@ -0,0 +1,256 @@
#!/usr/bin/env python

# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig


@PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass
class MultiTaskDiTConfig(PreTrainedConfig):
    """Configuration for the Multi-Task Diffusion Transformer (DiT) policy.

    A transformer-based policy that supports both diffusion and flow matching objectives
    for multi-task robot learning with text and vision conditioning.
    """

    n_obs_steps: int = 2  # Number of observation steps for temporal context
    horizon: int = 32  # Number of action steps to predict
    n_action_steps: int = 24  # Actions executed per policy call (~0.8s at 30Hz)

    # Objective Selection
    objective: str = "diffusion"  # "diffusion" or "flow_matching"

    # --- Diffusion-specific (used when objective="diffusion") ---
    noise_scheduler_type: str = "DDPM"  # "DDPM" or "DDIM"
    num_train_timesteps: int = 100  # Number of diffusion timesteps
    beta_schedule: str = "squaredcos_cap_v2"  # Noise schedule type
    beta_start: float = 0.0001  # Starting noise level
    beta_end: float = 0.02  # Ending noise level
    prediction_type: str = "epsilon"  # "epsilon" (predict noise) or "sample" (predict clean)
    clip_sample: bool = True  # Clip samples during denoising
    clip_sample_range: float = 1.0  # Clipping range [-x, x]
    num_inference_steps: int | None = None  # Denoising steps at inference (defaults to num_train_timesteps)

    # --- Flow Matching-specific (used when objective="flow_matching") ---
    sigma_min: float = 0.0  # Minimum noise in flow interpolation path
    num_integration_steps: int = 100  # ODE integration steps at inference
    integration_method: str = "euler"  # ODE solver: "euler" or "rk4"
    timestep_sampling_strategy: str = "beta"  # "uniform" or "beta"

    timestep_sampling_s: float = 0.999  # (beta only) Max timestep threshold
    timestep_sampling_alpha: float = 1.5  # (beta only) Beta distribution alpha
    timestep_sampling_beta: float = 1.0  # (beta only) Beta distribution beta

    # Transformer Architecture
    hidden_dim: int = 512  # Transformer hidden dimension
    num_layers: int = 6  # Number of transformer layers
    num_heads: int = 8  # Number of attention heads
    dropout: float = 0.1  # Dropout rate
    use_positional_encoding: bool = False  # Use absolute positional encoding
    timestep_embed_dim: int = 256  # Timestep embedding dimension
    use_rope: bool = True  # Use Rotary Position Embedding
    rope_base: float = 10000.0  # RoPE base frequency

    # Vision Encoder (CLIP)
    vision_encoder_name: str = "openai/clip-vit-base-patch16"  # HuggingFace CLIP model
    use_separate_rgb_encoder_per_camera: bool = False  # Separate encoder per camera view
    vision_encoder_lr_multiplier: float = 0.1  # LR multiplier for vision encoder
    image_resize_shape: tuple[int, int] | None = None  # Resize images before crop
    image_crop_shape: tuple[int, int] | None = (224, 224)  # Crop shape (CLIP default)
    image_crop_is_random: bool = True  # Random crop during training, center at inference

    # Text Encoder (CLIP)
    text_encoder_name: str = "openai/clip-vit-base-patch16"  # HuggingFace CLIP model
    tokenizer_max_length: int = 77  # Max length for tokenized text (CLIP default is 77)
    tokenizer_padding: str = "max_length"  # Padding strategy: "max_length" or "longest"
    tokenizer_padding_side: str = "right"  # Padding side: "left" or "right"
    tokenizer_truncation: bool = True  # Whether to truncate sequences longer than max_length

    # Normalization
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # Training/Optimizer
    optimizer_lr: float = 2e-5
    optimizer_betas: tuple = (0.95, 0.999)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 0.0
    scheduler_name: str = "cosine"
    scheduler_warmup_steps: int = 0
    do_mask_loss_for_padding: bool = False

    # Auto-calculated
    drop_n_last_frames: int | None = None

    def __post_init__(self):
        super().__post_init__()

        if self.drop_n_last_frames is None:
            self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1

        self._validate()

    def _validate(self):
        """Validate configuration parameters."""
        # Objective validation
        if self.objective not in ["diffusion", "flow_matching"]:
            raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")

        # Transformer validation
        if self.hidden_dim <= 0:
            raise ValueError("hidden_dim must be positive")
        if self.num_layers <= 0:
            raise ValueError("num_layers must be positive")
        if self.num_heads <= 0:
            raise ValueError("num_heads must be positive")
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")
        if not (0.0 <= self.dropout <= 1.0):
            raise ValueError("dropout must be between 0.0 and 1.0")

        # Vision encoder validation
        if "clip" not in self.vision_encoder_name.lower():
            raise ValueError(
                f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
            )
        if (
            self.image_resize_shape
            and self.image_crop_shape
            and (
                self.image_crop_shape[0] > self.image_resize_shape[0]
                or self.image_crop_shape[1] > self.image_resize_shape[1]
            )
        ):
            logging.warning(
                "image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
                self.image_crop_shape,
                self.image_resize_shape,
            )
            self.image_crop_shape = None

        # Text encoder validation
        if "clip" not in self.text_encoder_name.lower():
            raise ValueError(
                f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
            )

        # Objective-specific validation
        if self.objective == "diffusion":
            if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
                raise ValueError(
                    f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
                )
            if self.prediction_type not in ["epsilon", "sample"]:
                raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
            if self.num_train_timesteps <= 0:
                raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
            if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
                raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")

        elif self.objective == "flow_matching":
            if not (0.0 <= self.sigma_min <= 1.0):
                raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
            if self.num_integration_steps <= 0:
                raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
            if self.integration_method not in ["euler", "rk4"]:
                raise ValueError(
                    f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
                )
            if self.timestep_sampling_strategy not in ["uniform", "beta"]:
                raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
            if self.timestep_sampling_strategy == "beta":
                if not (0.0 < self.timestep_sampling_s <= 1.0):
                    raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
                if self.timestep_sampling_alpha <= 0:
                    raise ValueError("timestep_sampling_alpha must be positive")
                if self.timestep_sampling_beta <= 0:
                    raise ValueError("timestep_sampling_beta must be positive")

    def get_optimizer_preset(self) -> AdamConfig:
        return AdamConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
        return DiffuserSchedulerConfig(
            name=self.scheduler_name,
            num_warmup_steps=self.scheduler_warmup_steps,
        )

    def validate_features(self) -> None:
        """Validate that required input features are present and properly configured."""
        # If the configured crop doesn't fit, disable cropping instead of erroring.
        # Note: if image_resize_shape is set, cropping is applied *after* resizing.
        if self.image_crop_shape is not None:
            for key, image_ft in self.image_features.items():
                # image_ft.shape is (C, H, W)
                effective_h, effective_w = (
                    self.image_resize_shape
                    if self.image_resize_shape is not None
                    else (image_ft.shape[1], image_ft.shape[2])
                )
                if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
                    logging.warning(
                        "image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
                        self.image_crop_shape,
                        effective_h,
                        effective_w,
                        key,
                    )
                    self.image_crop_shape = None
                    break

        if len(self.image_features) > 0:
            first_key, first_ft = next(iter(self.image_features.items()))
            for key, image_ft in self.image_features.items():
                if image_ft.shape != first_ft.shape:
                    raise ValueError(
                        f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
                    )

    @property
    def is_diffusion(self) -> bool:
        return self.objective == "diffusion"

    @property
    def is_flow_matching(self) -> bool:
        return self.objective == "flow_matching"

    @property
    def observation_delta_indices(self) -> list:
        return list(range(1 - self.n_obs_steps, 1))

    @property
    def action_delta_indices(self) -> list:
        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))

    @property
    def reward_delta_indices(self) -> None:
        return None
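A small worked example of the values this config derives with its defaults (`n_obs_steps=2`, `horizon=32`, `n_action_steps=24`); this only replays the arithmetic from `__post_init__` and the delta-index properties above:

```python
n_obs_steps, horizon, n_action_steps = 2, 32, 24

drop_n_last_frames = horizon - n_action_steps - n_obs_steps + 1  # 7
observation_delta_indices = list(range(1 - n_obs_steps, 1))  # [-1, 0]
action_delta_indices = list(range(1 - n_obs_steps, 1 - n_obs_steps + horizon))  # [-1, 0, ..., 30]
```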
@@ -0,0 +1,803 @@
#!/usr/bin/env python

# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multi-Task Diffusion Transformer (DiT) Policy

Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
Supports both diffusion and flow matching objectives for action generation.

References:
    - https://arxiv.org/abs/2507.05331
    - https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/
    - https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
"""

import math
from collections import deque
from typing import TYPE_CHECKING

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor

from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available

# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
    from transformers import CLIPTextModel, CLIPVisionModel
else:
    CLIPTextModel = None
    CLIPVisionModel = None

from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
    ACTION,
    OBS_IMAGES,
    OBS_LANGUAGE_ATTENTION_MASK,
    OBS_LANGUAGE_TOKENS,
    OBS_STATE,
)

# -- Policy --


class MultiTaskDiTPolicy(PreTrainedPolicy):
    config_class = MultiTaskDiTConfig
    name = "multi_task_dit"

    def __init__(self, config: MultiTaskDiTConfig, **kwargs):
        super().__init__(config)
        config.validate_features()
        self.config = config

        self._queues = None

        self.observation_encoder = ObservationEncoder(config)
        conditioning_dim = self.observation_encoder.conditioning_dim
        self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)

        action_dim = config.action_feature.shape[0]
        horizon = config.horizon

        if config.is_diffusion:
            self.objective = DiffusionObjective(
                config,
                action_dim=action_dim,
                horizon=horizon,
                do_mask_loss_for_padding=config.do_mask_loss_for_padding,
            )
        elif config.is_flow_matching:
            self.objective = FlowMatchingObjective(
                config,
                action_dim=action_dim,
                horizon=horizon,
                do_mask_loss_for_padding=config.do_mask_loss_for_padding,
            )
        else:
            raise ValueError(f"Unsupported objective: {config.objective}")

        self.reset()

    def get_optim_params(self) -> list:
        """Returns parameter groups with different learning rates for vision vs non-vision parameters."""
        non_vision_params = []
        vision_encoder_params = []

        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue

            if "observation_encoder.vision_encoder" in name:
                vision_encoder_params.append(param)
            else:
                non_vision_params.append(param)

        return [
            {"params": non_vision_params},
            {
                "params": vision_encoder_params,
                "lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
            },
        ]

    def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        conditioning_vec = self.observation_encoder.encode(batch)
        actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)

        # Keep only the n_action_steps that start at the current timestep.
        start = n_obs_steps - 1
        end = start + self.config.n_action_steps
        actions = actions[:, start:end]
        return actions

    def reset(self):
        """Clear observation and action queues. Should be called on `env.reset()`."""
        self._queues = {
            OBS_STATE: deque(maxlen=self.config.n_obs_steps),
            ACTION: deque(maxlen=self.config.n_action_steps),
        }

        if self.config.image_features:
            self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict a chunk of actions given environment observations."""
        self.eval()

        for k in batch:
            if k in self._queues:
                batch[k] = torch.stack(list(self._queues[k]), dim=1)

        actions = self._generate_actions(batch)
        return actions

    def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Prepare batch by stacking image features if needed."""
        if self.config.image_features:
            batch = dict(batch)  # shallow copy to avoid modifying original
            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)

        return batch

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations."""
        if ACTION in batch:
            batch = dict(batch)  # shallow copy to avoid modifying original
            batch.pop(ACTION)

        batch = self._prepare_batch(batch)

        self._queues = populate_queues(self._queues, batch)

        if len(self._queues[ACTION]) == 0:
            actions = self.predict_action_chunk(batch)
            self._queues[ACTION].extend(actions.transpose(0, 1))

        action = self._queues[ACTION].popleft()
        return action

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
        """Run the batch through the model and compute the loss for training."""
        batch = self._prepare_batch(batch)

        conditioning_vec = self.observation_encoder.encode(batch)
        loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)

        return loss, None


# -- Observation Encoders --


class CLIPVisionEncoder(nn.Module):
    """CLIP vision encoder using the CLS token for global image representation."""

    def __init__(self, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.model = CLIPVisionModel.from_pretrained(self.model_name)
        self.num_non_spatial_tokens = 1
        self.embed_dim = self.model.config.hidden_size

    def forward(self, x: Tensor) -> Tensor:
        """Encode RGB image to CLS token."""
        outputs = self.model(pixel_values=x, output_hidden_states=False)
        cls_token = outputs.last_hidden_state[:, 0]
        b, embed_dim = cls_token.shape
        return cls_token.reshape(b, embed_dim, 1, 1)

    def get_output_shape(self) -> tuple:
        return (self.embed_dim, 1, 1)


class CLIPTextEncoder(nn.Module):
    """CLIP text encoder with frozen weights and a learnable projection layer.

    Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline.
    See the processor pipeline for how the tokenization is handled.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
        super().__init__()
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.text_encoder = CLIPTextModel.from_pretrained(model_name)

        for param in self.text_encoder.parameters():
            param.requires_grad = False

        self.text_embed_dim = self.text_encoder.config.hidden_size
        self.projection = nn.Linear(self.text_embed_dim, projection_dim)

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Encode pre-tokenized text to feature vectors."""
        # Ensure inputs are on the same device as the model
        device = next(self.parameters()).device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        with torch.no_grad():
            outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            clip_features = outputs.pooler_output

        return self.projection(clip_features)


class ObservationEncoder(nn.Module):
    """Handles all observation processing for the conditioning vector."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self._setup_preprocessing(config)

        if config.image_features:
            self.num_cameras = len(config.image_features)
            self.camera_names = list(config.image_features.keys())

            if config.use_separate_rgb_encoder_per_camera:
                self.vision_encoders = nn.ModuleList(
                    [CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
                )
                self.vision_encoder = None
            else:
                self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
                self.vision_encoders = None
        else:
            self.vision_encoder = None
            self.vision_encoders = None
            self.camera_names = []
            self.num_cameras = 0

        if hasattr(config, "robot_state_feature") and config.robot_state_feature:
            self.robot_state_dim = config.robot_state_feature.shape[0]
        else:
            self.robot_state_dim = 0

        self.text_dim = config.hidden_dim
        self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)

        self._setup_vector_output()

    def _apply_preprocessing(self, images: Tensor) -> Tensor:
        if self.do_resize:
            images = self.resize(images)
        if self.do_crop:
            images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
        return images

    def _setup_preprocessing(self, config):
        if config.image_resize_shape is not None:
            self.do_resize = True
            self.resize = torchvision.transforms.Resize(
                size=config.image_resize_shape,
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                antialias=True,
            )
        else:
            self.do_resize = False

        if config.image_crop_shape is not None:
            self.do_crop = True
            self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
            if config.image_crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False

    def _setup_vector_output(self):
        total_dim = 0

        if self.vision_encoder is not None or self.vision_encoders is not None:
            encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
            feature_map_shape = encoder_to_check.get_output_shape()
            c, h, w = feature_map_shape
            spatial_feature_dim = c * h * w
            total_dim += spatial_feature_dim * self.num_cameras

        total_dim += self.robot_state_dim
        total_dim += self.text_dim

        self.conditioning_dim = total_dim * self.config.n_obs_steps

    def encode(self, batch: dict) -> Tensor:
        """Encode observations to vector format."""
        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
        conditioning_feats = []

        conditioning_feats.append(batch[OBS_STATE])

        if self.vision_encoder is not None or self.vision_encoders is not None:
            images = batch[OBS_IMAGES]

            # Add a temporal dim if observations come without a history axis:
            # (b, n, c, h, w) -> (b, 1, n, c, h, w)
            if len(images.shape) == 5:
                images = images.unsqueeze(1)

            if self.config.use_separate_rgb_encoder_per_camera:
                camera_features = []
                for cam_idx in range(self.num_cameras):
                    cam_images = images[:, :, cam_idx]
                    cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
                    cam_images_flat = self._apply_preprocessing(cam_images_flat)
                    cam_features = self.vision_encoders[cam_idx](cam_images_flat)
                    cam_visual_features = cam_features.flatten(start_dim=1)
                    cam_features_reshaped = einops.rearrange(
                        cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
                    )
                    camera_features.append(cam_features_reshaped)
                img_features = torch.cat(camera_features, dim=-1)
                conditioning_feats.append(img_features)
            else:
                images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
                images_flat = self._apply_preprocessing(images_flat)
                visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
                img_features = einops.rearrange(
                    visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
                )
                conditioning_feats.append(img_features)

        if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
            input_ids = batch[OBS_LANGUAGE_TOKENS]  # [batch_size, seq_length]
            attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK]  # [batch_size, seq_length]

            text_features = self.text_encoder(input_ids, attention_mask)

            # Broadcast the single text embedding across all observation steps.
            text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
            conditioning_feats.append(text_features)

        combined_features = torch.cat(conditioning_feats, dim=-1)
        return combined_features.flatten(start_dim=1)


# -- Transformer Components --


def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """Modulate input with shift and scale for AdaLN-Zero."""
    return x * (1 + scale) + shift


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embeddings for timesteps."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) for transformers."""

    def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even for RoPE"

        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._precompute_cache(max_seq_len)

    def _precompute_cache(self, seq_len: int):
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def _rotate_half(self, x: Tensor) -> Tensor:
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
        seq_len = q.shape[2]
        if seq_len > self.max_seq_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.")

        cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
        sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)

        q_rotated = (q * cos) + (self._rotate_half(q) * sin)
        k_rotated = (k * cos) + (self._rotate_half(k) * sin)
        return q_rotated, k_rotated
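
# Illustrative sketch (an editorial note, not part of the committed file): the
# module above rotates q/k and preserves their shapes, e.g.
#
#   rope = RotaryPositionalEmbedding(head_dim=64, max_seq_len=32)
#   q = torch.randn(2, 8, 32, 64)  # (batch, heads, seq, head_dim)
#   k = torch.randn(2, 8, 32, 64)
#   q_rot, k_rot = rope(q, k)
#   assert q_rot.shape == q.shape and k_rot.shape == k.shape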

class RoPEAttention(nn.Module):
    """Multi-head self-attention with Rotary Position Embedding (RoPE)."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float = 0.0,
        max_seq_len: int = 512,
        rope_base: float = 10000.0,
    ):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        B, T, _ = x.shape  # noqa: N806

        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q, k = self.rope(q, k)

        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
        )

        attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)
        return self.out_proj(attn_out)


class TransformerBlock(nn.Module):
    """DiT-style transformer block with AdaLN-Zero."""

    def __init__(
        self,
        hidden_size: int = 128,
        num_heads: int = 4,
        num_features: int = 128,
        dropout: float = 0.0,
        use_rope: bool = False,
        max_seq_len: int = 512,
        rope_base: float = 10000.0,
    ):
        super().__init__()
        self.use_rope = use_rope

        if use_rope:
            self.attn = RoPEAttention(
                hidden_size=hidden_size,
                num_heads=num_heads,
                dropout=dropout,
                max_seq_len=max_seq_len,
                rope_base=rope_base,
            )
        else:
            self.multihead_attn = nn.MultiheadAttention(
                hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
            )

        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))

    def forward(self, x: Tensor, features: Tensor) -> Tensor:
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
            features
        ).chunk(6, dim=1)

        attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))

        if self.use_rope:
            attn_out = self.attn(attn_input)
        else:
            attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)

        x = x + gate_msa.unsqueeze(1) * attn_out

        mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
        mlp_out = self.mlp(mlp_input)
        x = x + gate_mlp.unsqueeze(1) * mlp_out

        return x


class DiffusionTransformer(nn.Module):
    """Transformer-based diffusion noise prediction model."""

    def __init__(self, config, conditioning_dim: int):
        super().__init__()
        self.config = config
        self.conditioning_dim = conditioning_dim

        self.action_dim = config.action_feature.shape[0]
        self.horizon = config.horizon
        self.hidden_size = config.hidden_dim
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.dropout = config.dropout
        self.use_rope = config.use_rope

        self.timestep_embed_dim = config.timestep_embed_dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(self.timestep_embed_dim),
            nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
            nn.GELU(),
            nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
            nn.GELU(),
        )

        self.cond_dim = self.timestep_embed_dim + conditioning_dim
        self.input_proj = nn.Linear(self.action_dim, self.hidden_size)

        if config.use_positional_encoding:
            self.pos_embedding = nn.Parameter(
                torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
            )
        else:
            self.pos_embedding = None

        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    num_features=self.cond_dim,
                    dropout=self.dropout,
                    use_rope=self.use_rope,
                    max_seq_len=self.horizon,
                    rope_base=config.rope_base,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
        self._initialize_weights()

    def _initialize_weights(self):
        # AdaLN-Zero: zero the modulation layers so each block starts as the identity.
        for block in self.transformer_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
        _, seq_len, _ = x.shape

        timestep_features = self.time_mlp(timestep)
        cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1)

        hidden_seq = self.input_proj(x)

        if self.pos_embedding is not None:
            hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :]

        for block in self.transformer_blocks:
            hidden_seq = block(hidden_seq, cond_features)

        return self.output_proj(hidden_seq)


# -- Objectives --


class DiffusionObjective(nn.Module):
    """Standard diffusion (DDPM/DDIM) objective implementation."""

    def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
        super().__init__()
        self.config = config
        self.action_dim = action_dim
        self.horizon = horizon
        self.do_mask_loss_for_padding = do_mask_loss_for_padding

        scheduler_kwargs = {
            "num_train_timesteps": config.num_train_timesteps,
            "beta_start": config.beta_start,
            "beta_end": config.beta_end,
            "beta_schedule": config.beta_schedule,
            "clip_sample": config.clip_sample,
            "clip_sample_range": config.clip_sample_range,
            "prediction_type": config.prediction_type,
        }

        if config.noise_scheduler_type == "DDPM":
            self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
        elif config.noise_scheduler_type == "DDIM":
            self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
        else:
            raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")

        self.num_inference_steps = (
            config.num_inference_steps
            if config.num_inference_steps is not None
            else self.noise_scheduler.config.num_train_timesteps
        )

    def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
        clean_actions = batch[ACTION]
        noise = torch.randn_like(clean_actions)
        timesteps = torch.randint(
            low=0,
            high=self.noise_scheduler.config.num_train_timesteps,
            size=(clean_actions.shape[0],),
            device=clean_actions.device,
        ).long()
        noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)

        prediction_type = self.noise_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "sample":
            target = clean_actions
        else:
            raise ValueError(f"Unsupported prediction type: {prediction_type}")

        predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
        loss = F.mse_loss(predicted, target, reduction="none")

        if self.do_mask_loss_for_padding and "action_is_pad" in batch:
            valid_actions = ~batch["action_is_pad"]
            loss = loss * valid_actions.unsqueeze(-1)

        return loss.mean()

    def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype

        sample = torch.randn(
            size=(batch_size, self.horizon, self.action_dim),
            dtype=dtype,
            device=device,
        )

        self.noise_scheduler.set_timesteps(self.num_inference_steps)
        for t in self.noise_scheduler.timesteps:
            model_output = model(
                sample,
                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                conditioning_vec=conditioning_vec,
            )
            sample = self.noise_scheduler.step(model_output, t, sample).prev_sample

        return sample


class FlowMatchingObjective(nn.Module):
    """Flow matching objective: trains a model to predict velocity fields."""

    def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
        super().__init__()
        self.config = config
        self.action_dim = action_dim
        self.horizon = horizon
        self.do_mask_loss_for_padding = do_mask_loss_for_padding

    def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
        if self.config.timestep_sampling_strategy == "uniform":
            return torch.rand(batch_size, device=device)
        elif self.config.timestep_sampling_strategy == "beta":
            beta_dist = torch.distributions.Beta(
                self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
            )
            u = beta_dist.sample((batch_size,)).to(device)
            return self.config.timestep_sampling_s * (1.0 - u)
        else:
            raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")

    def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
        data = batch[ACTION]
        batch_size = data.shape[0]
        device = data.device

        noise = torch.randn_like(data)
        t = self._sample_timesteps(batch_size, device)
        t_expanded = t.view(-1, 1, 1)
        x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise

        target_velocity = data - (1 - self.config.sigma_min) * noise
        predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
        loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")

        if self.do_mask_loss_for_padding and "action_is_pad" in batch:
            valid_mask = ~batch["action_is_pad"]
            loss = loss * valid_mask.unsqueeze(-1)

        return loss.mean()

    def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype

        x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)

        num_steps = self.config.num_integration_steps
        time_grid = torch.linspace(0, 1, num_steps + 1, device=device)

        if self.config.integration_method == "euler":
            x = self._euler_integrate(model, x, time_grid, conditioning_vec)
        elif self.config.integration_method == "rk4":
            x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
        else:
            raise ValueError(f"Unknown integration method: {self.config.integration_method}")

        return x

    def _euler_integrate(
        self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
    ) -> Tensor:
        x = x_init
        for i in range(len(time_grid) - 1):
            t_scalar = time_grid[i].item()
            dt = (time_grid[i + 1] - time_grid[i]).item()
            t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
            with torch.no_grad():
                velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
            x = x + dt * velocity
        return x

    def _rk4_integrate(
        self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
    ) -> Tensor:
        x = x_init

        def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
            t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
            with torch.no_grad():
                return model(x_val, t_batch, conditioning_vec=conditioning_vec)

        for i in range(len(time_grid) - 1):
            t = time_grid[i].item()
            dt = (time_grid[i + 1] - time_grid[i]).item()

            k1 = dynamics(x, t)
            k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
            k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
            k4 = dynamics(x + dt * k3, t + dt)

            x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

        return x
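With the default config (`n_obs_steps=2`, `horizon=32`, `n_action_steps=24`), `_generate_actions` above keeps the slice of the sampled trajectory that starts at the current timestep. A sketch of that arithmetic, using only the defaults:

```python
n_obs_steps, n_action_steps, horizon = 2, 24, 32

start = n_obs_steps - 1  # 1: the action aligned with the latest observation
end = start + n_action_steps  # 25
# actions[:, 1:25] -> 24 actions queued before the next model call
assert end <= horizon
```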
@@ -0,0 +1,105 @@
#!/usr/bin/env python

# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch

from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    RenameObservationsProcessorStep,
    TokenizerProcessorStep,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME


def make_multi_task_dit_pre_post_processors(
    config: MultiTaskDiTConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy.

    The pre-processing pipeline prepares the input data for the model by:
    1. Renaming features.
    2. Adding a batch dimension.
    3. Tokenizing the language task description (if present).
    4. Moving the data to the specified device.
    5. Normalizing the input and output features based on dataset statistics.

    The post-processing pipeline handles the model's output by:
    1. Unnormalizing the output features to their original scale.
    2. Moving the data to the CPU.

    Args:
        config: The configuration object for the Multi-Task DiT policy,
            containing feature definitions, normalization mappings, and device information.
        dataset_stats: A dictionary of statistics used for normalization.
            Defaults to None.

    Returns:
        A tuple containing the configured pre-processor and post-processor pipelines.
    """

    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        TokenizerProcessorStep(
            tokenizer_name=config.text_encoder_name,
            padding=config.tokenizer_padding,
            padding_side=config.tokenizer_padding_side,
            max_length=config.tokenizer_max_length,
            truncation=config.tokenizer_truncation,
        ),
        DeviceProcessorStep(device=config.device),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
            device=config.device,
        ),
    ]
    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        DeviceProcessorStep(device="cpu"),
    ]

    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
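A minimal sketch of wiring these pipelines around the policy at inference time (hedged: `cfg`, `dataset_stats`, `policy`, and `raw_obs` are placeholders, and the pipelines are assumed to be invoked as callables, as elsewhere in lerobot):

```python
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
    config=cfg,
    dataset_stats=dataset_stats,
)

batch = preprocessor(raw_obs)  # rename, add batch dim, tokenize task, to device, normalize
action = policy.select_action(batch)
action = postprocessor(action)  # unnormalize and move to CPU
```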