Add multitask diffusion transformer policy

This commit is contained in:
Bryson Jones
2025-11-12 16:20:59 -08:00
committed by GitHub
parent a5b29d4301
commit 14a7a4d7d4
13 changed files with 2080 additions and 4 deletions
+1
@@ -118,6 +118,7 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "timm>=1.0.20"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
+1 -1
@@ -157,7 +157,7 @@ available_datasets = sorted(
)
# lists all available policies from `lerobot/policies`
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
available_policies = ["act", "multi_task_dit", "diffusion", "tdmpc", "vqbet"]
# lists all available robots from `lerobot/robots`
available_robots = [
+2
@@ -15,6 +15,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
@@ -25,6 +26,7 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"MultiTaskDiTConfig",
"PI0Config",
"PI05Config",
"SmolVLAConfig",
+20 -3
@@ -31,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -59,7 +60,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
Returns:
The policy class corresponding to the given name.
@@ -79,6 +80,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "multi_task_dit":
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
return MultiTaskDiTPolicy
elif name == "vqbet":
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
@@ -120,8 +125,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier".
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"reward_classifier", "smolvla".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -136,6 +141,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "multi_task_dit":
return MultiTaskDiTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
@@ -274,6 +281,16 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MultiTaskDiTConfig):
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
processors = make_multi_task_dit_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
@@ -0,0 +1,45 @@
# Multi-Task DiT Policy
For details on the architecture, see the citations below and the blog post by Bryson Jones: https://brysonkjones.substack.com/p/dissecting-multi-task-diffusion-transformer-policy
## Training and Inference Baseline Recommendations
### Training
- Number of demonstrations: >100 per task
- Batch Size: 320
- Objective: Diffusion
- Cameras: At least two, with one egocentric view per arm
### Inference
- GPU: NVIDIA RTX 5070 Ti or better
- Sampling:
- Strategy: DDIM
- Number of Timesteps: 10
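The sampling recommendation above maps onto the policy configuration roughly as follows. This is a sketch only, using the config classes introduced in this commit; depending on your setup you will also need to provide input/output features and dataset statistics.

```python
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
    DiffusionConfig,  # the objective config defined in this module (not the Diffusion Policy config)
    MultiTaskDiTConfig,
)

# DDIM sampling with 10 denoising steps at inference time, as recommended above.
config = MultiTaskDiTConfig(
    objective=DiffusionConfig(
        noise_scheduler_type="DDIM",
        num_inference_steps=10,
    ),
)
```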
## Citation
If you use this work, please cite the following:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
@@ -0,0 +1,433 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
import draccus
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@dataclass
class ObjectiveConfig(draccus.ChoiceRegistry):
"""Base configuration for model objectives (diffusion, flow matching, etc.)."""
pass
@ObjectiveConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig(ObjectiveConfig):
"""Configuration for standard diffusion model training and inference.
These parameters control the noise scheduling and denoising process for
standard DDPM/DDIM diffusion models.
"""
objective_name: str = field(default="diffusion", init=False)
# Noise scheduler configuration - controls diffusion process
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # 100 noise levels for fine-grained control
beta_schedule: str = "squaredcos_cap_v2" # Cosine schedule prevents extreme noise
beta_start: float = 0.0001 # Small initial noise level
beta_end: float = 0.02 # Moderate final noise level
prediction_type: str = "epsilon" # Predict noise (works better than direct prediction)
clip_sample: bool = True # Prevent extreme action values
clip_sample_range: float = 1.0 # Clip to [-1, 1] range
# Inference configuration
num_inference_steps: int | None = None # Default to num_train_timesteps
def __post_init__(self):
"""Validate diffusion-specific parameters."""
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(
"beta values must satisfy 0 <= beta_start <= beta_end <= 1, "
f"got {self.beta_start}, {self.beta_end}"
)
@dataclass
class TimestepSamplingConfig(draccus.ChoiceRegistry):
"""Base configuration for timestep sampling strategies during training."""
pass
@TimestepSamplingConfig.register_subclass("uniform")
@dataclass
class UniformTimestepSamplingConfig(TimestepSamplingConfig):
"""Uniform timestep sampling from [0, 1]."""
strategy_name: str = field(default="uniform", init=False)
@TimestepSamplingConfig.register_subclass("beta")
@dataclass
class BetaTimestepSamplingConfig(TimestepSamplingConfig):
"""Beta distribution timestep sampling.
Samples from Beta distribution emphasizing low timesteps (high noise).
This is inspired by the Physical Intelligence PI-0 model, which suggested that
sampling timesteps from a beta distribution during training improves sample quality.
"""
strategy_name: str = field(default="beta", init=False)
s: float = 0.999 # Max timestep threshold for beta sampling
alpha: float = 1.5 # Beta distribution alpha parameter
beta: float = 1.0 # Beta distribution beta parameter
def __post_init__(self):
if not (0.0 < self.s <= 1.0):
raise ValueError(f"s must be in (0, 1], got {self.s}")
if self.alpha <= 0:
raise ValueError(f"alpha must be positive, got {self.alpha}")
if self.beta <= 0:
raise ValueError(f"beta must be positive, got {self.beta}")
@ObjectiveConfig.register_subclass("flow_matching")
@dataclass
class FlowMatchingConfig(ObjectiveConfig):
"""Configuration for flow matching training and inference.
These parameters control the velocity field learning and ODE integration
process for flow matching models.
"""
objective_name: str = field(default="flow_matching", init=False)
# Flow path construction
sigma_min: float = 0.0 # Minimum noise level in flow interpolation path
# ODE integration for inference
num_integration_steps: int = (
100 # Number of ODE integration steps (increased from 50 for smoother trajectories)
)
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
# Timestep sampling strategy for training
# Beta distribution found to be the most effective in practice, so it is the default
timestep_sampling: TimestepSamplingConfig = field(default_factory=BetaTimestepSamplingConfig)
def __post_init__(self):
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(f"integration_method must be 'euler' or 'rk4', got {self.integration_method}")
@dataclass
class TransformerConfig:
"""Configuration for Transformer-based prediction model.
These parameters control the transformer architecture used for noise/velocity
prediction in diffusion and flow matching models.
"""
# Transformer architecture parameters
hidden_dim: int = 512 # Hidden dimension of transformer
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = True # Whether to use positional encoding
diffusion_step_embed_dim: int = 256 # Timestep embedding size
def __post_init__(self):
"""Validate Transformer-specific parameters."""
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
if self.diffusion_step_embed_dim <= 0:
raise ValueError("diffusion_step_embed_dim must be positive")
@dataclass
class VisionEncoderConfig(draccus.ChoiceRegistry):
"""Base configuration for vision encoders.
All image preprocessing is centralized here:
1. Resize (optional) - resize images to target resolution
2. Crop (optional) - crop after resize, must be smaller than resize_shape
3. Random crop - whether to use random cropping during training
"""
# Common parameters across all vision encoders
use_separate_encoder_per_camera: bool = False
# Learning rate multiplier for vision encoder parameters
# Vision encoder learning rate = optimizer_lr * lr_multiplier
lr_multiplier: float = 0.1
# Image preprocessing (centralized)
resize_shape: tuple[int, int] | None = None
crop_shape: tuple[int, int] | None = (224, 224) # default input size for CLIP
crop_is_random: bool = True
def __post_init__(self):
if (
self.resize_shape
and self.crop_shape
and (self.crop_shape[0] > self.resize_shape[0] or self.crop_shape[1] > self.resize_shape[1])
):
raise ValueError(
f"crop_shape {self.crop_shape} must be smaller than or equal to "
f"resize_shape {self.resize_shape}. Got crop={self.crop_shape}, resize={self.resize_shape}"
)
@VisionEncoderConfig.register_subclass("dinov3")
@dataclass
class DinoV3EncoderConfig(VisionEncoderConfig):
"""DinoV3 vision encoder configuration.
DinoV3 is a self-supervised Vision Transformer trained by Meta.
CLS token usage and spatial feature extraction are handled automatically.
Available backbones:
- vit_base_patch16_dinov3.lvd1689m (768 dims)
"""
backbone: str = "vit_base_patch16_dinov3.lvd1689m"
def __post_init__(self):
super().__post_init__()
# Validate backbone name
valid_backbones = [
"vit_base_patch16_dinov3.lvd1689m",
]
if self.backbone not in valid_backbones:
raise ValueError(f"backbone must be one of {valid_backbones}, got '{self.backbone}'")
@VisionEncoderConfig.register_subclass("clip")
@dataclass
class CLIPVisionEncoderConfig(VisionEncoderConfig):
"""CLIP vision encoder configuration.
CLIP is a vision-language model trained by OpenAI.
CLS token usage is handled automatically.
CLIP's internal preprocessing (resize to 224x224) can be overridden
by setting resize_shape and crop_shape.
Available backbones:
- vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224)
"""
backbone: str = "vit_base_patch16_clip_224.openai"
def __post_init__(self):
super().__post_init__()
# Validate backbone name
if "clip" not in self.backbone.lower():
raise ValueError(f"backbone must be a CLIP model, got '{self.backbone}'")
@dataclass
class TextEncoderConfig(draccus.ChoiceRegistry):
"""Base configuration for text encoders.
If a text encoder is set in ObservationEncoderConfig, text conditioning
is automatically enabled.
"""
pass
def __post_init__(self):
pass
@TextEncoderConfig.register_subclass("clip")
@dataclass
class CLIPTextEncoderConfig(TextEncoderConfig):
"""CLIP text encoder for task conditioning.
Uses CLIP's text encoder to embed task descriptions, which are then
used to condition the policy. The text embeddings are processed by
a learnable projection layer before being concatenated into the
conditioning vector.
"""
model: str = "openai/clip-vit-base-patch16"
def __post_init__(self):
super().__post_init__()
if "clip" not in self.model.lower():
raise ValueError(f"CLIP text encoder requires a CLIP model. Got '{self.model}'")
@dataclass
class ObservationEncoderConfig:
"""Top-level configuration for observation encoding.
This config combines:
- Vision encoding (required): DinoV3 or CLIP vision encoder
"""
vision: VisionEncoderConfig = field(default_factory=CLIPVisionEncoderConfig)
text: TextEncoderConfig = field(default_factory=CLIPTextEncoderConfig)
@PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass
class MultiTaskDiTConfig(PreTrainedConfig):
"""
Configuration class for the Multi-Task Diffusion Transformer (DiT) policy.
"""
# Temporal structure - controls how the policy processes time and predicts actions
n_obs_steps: int = 2 # num observations for temporal context (..., t-1, t)
horizon: int = 100 # predicted action steps into the future
n_action_steps: int = 24 # actions per policy call (receding horizon) -- ~0.8s is a good place to start
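# Example (illustrative, assuming a 30 Hz control rate): n_action_steps=24 covers 24 / 30 ≈ 0.8 s
# of executed actions per policy call before re-planning, matching the "~0.8s" note above.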
# Normalization strategy - critical for diffusion model performance
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD, # Standard ImageNet normalization for vision
"STATE": NormalizationMode.MIN_MAX, # [-1,1] range for proper diffusion clipping
"ACTION": NormalizationMode.MIN_MAX, # [-1,1] range required for diffusion process
}
)
drop_n_last_frames: int | None = None # Auto-calculated: horizon - n_action_steps - n_obs_steps + 1
observation_encoder: ObservationEncoderConfig = field(default_factory=ObservationEncoderConfig)
transformer: TransformerConfig = field(default_factory=TransformerConfig)
objective: ObjectiveConfig = field(default_factory=DiffusionConfig)
do_mask_loss_for_padding: bool = False # same logic as is implemented in LeRobot DP implementation
# training optimizer and scheduler hyperparameters
optimizer_lr: float = 2e-5
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0 # No weight decay is suggested to be optimal
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 0 # No warmup found to be optimal
def __post_init__(self):
super().__post_init__()
if self.drop_n_last_frames is None:
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
def get_optimizer_preset(self) -> AdamConfig:
"""Return Adam optimizer configuration optimized for diffusion training.
Note: Vision encoder learning rate is set separately via get_optim_params.
"""
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
"""Return learning rate scheduler configuration."""
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
"""Validate that required input features are present and properly configured."""
# Robot state is always present via self.robot_state_feature, so we don't need to enforce images/env_state
# This allows for testing and simple state-only policies
# Validate crop shape fits within image dimensions
crop_shape = self.observation_encoder.vision.crop_shape
if crop_shape is not None:
for key, image_ft in self.image_features.items():
if crop_shape[0] > image_ft.shape[1] or crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
# Ensure all images have same shape (current limitation)
if len(self.image_features) > 0:
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
)
@property
def model_objective(self) -> str:
return self.objective.objective_name
@property
def is_diffusion(self) -> bool:
return isinstance(self.objective, DiffusionConfig)
@property
def is_flow_matching(self) -> bool:
return isinstance(self.objective, FlowMatchingConfig)
def get_objective_config(self) -> DiffusionConfig | FlowMatchingConfig:
"""Get the objective-specific configuration with proper typing."""
return self.objective
@property
def observation_delta_indices(self) -> list:
"""Delta indices for stacking observations. Provides temporal context."""
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
"""Delta indices for action horizon prediction."""
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
"""Indices for reward deltas (not used in diffusion policy)."""
return None
@@ -0,0 +1,178 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Task Diffusion Transformer (DiT) Policy
Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
Supports both diffusion and flow matching objectives for action generation.
"""
from collections import deque
import torch
from torch import Tensor
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.multi_task_dit.modules.objectives import DiffusionObjective, FlowMatchingObjective
from lerobot.policies.multi_task_dit.modules.observation_encoder import ObservationEncoder
from lerobot.policies.multi_task_dit.modules.transformer import DiffusionTransformer
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_IMAGES
class MultiTaskDiTPolicy(PreTrainedPolicy):
config_class = MultiTaskDiTConfig
name = "multi_task_dit"
def __init__(self, config: MultiTaskDiTConfig):
super().__init__(config)
config.validate_features()
self.config = config
self._queues = None
self.observation_encoder = ObservationEncoder(config)
conditioning_dim = self.observation_encoder.conditioning_dim
self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)
action_dim = config.action_feature.shape[0]
horizon = config.horizon
self.model_objective = config.model_objective
if config.is_diffusion:
self.objective = DiffusionObjective(
config.get_objective_config(),
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
elif config.is_flow_matching:
self.objective = FlowMatchingObjective(
config.get_objective_config(),
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
else:
raise ValueError(f"Unsupported model_objective: {self.model_objective}")
self.reset()
def get_optim_params(self) -> list:
"""Returns parameter groups with different learning rates for vision vs non-vision parameters."""
non_vision_params = []
vision_encoder_params = []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
if "observation_encoder.vision_encoder" in name:
vision_encoder_params.append(param)
else:
non_vision_params.append(param)
return [
{"params": non_vision_params},
{
"params": vision_encoder_params,
"lr": self.config.optimizer_lr * self.config.observation_encoder.vision.lr_multiplier,
},
]
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
conditioning_vec = self.observation_encoder.encode(batch)
actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)
start_idx = n_obs_steps - 1
end_idx = start_idx + self.config.n_action_steps
return actions[:, start_idx:end_idx]
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`."""
self._queues = {
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
if self.config.observation_encoder.text:
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training or validation."""
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
n_obs_steps = batch["observation.state"].shape[1]
horizon = batch["action"].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
conditioning_vec = self.observation_encoder.encode(batch)
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
return loss, None
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
original_batch_keys = set(batch.keys())
new_batch = {}
for k in self._queues:
if k in original_batch_keys:
if self._queues[k] and isinstance(self._queues[k][-1][0], str):
# for task description which is a list of strings
new_batch[k] = self._queues[k][-1]
else:
queue_values = list(self._queues[k])
new_batch[k] = torch.stack(queue_values, dim=1)
batch = new_batch
actions = self._generate_actions(batch)
return actions
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method manages caching of observations and actions by generating an action chunk
and returning actions from the cache until it's depleted.
"""
if ACTION in batch:
batch.pop(ACTION)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
return self._queues[ACTION].popleft()
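# Usage sketch (illustrative only; `env` is a hypothetical environment wrapper and is not part
# of this module). A closed-loop rollout alternates `reset()` at episode start with
# `select_action()` every control step; action chunks are cached internally and a new chunk is
# generated once the queue is depleted:
#
#     policy = MultiTaskDiTPolicy(config)
#     policy.reset()
#     observation = env.reset()
#     for _ in range(max_steps):
#         action = policy.select_action(observation)
#         observation = env.step(action)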
@@ -0,0 +1,305 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains a base objective class, and implementation of the objective
classes for use in the Multi-Task Diffusion Transformer Policy.
Architecture:
- BaseObjective: Abstract interface definition
- DiffusionObjective: Implements standard DDPM/DDIM diffusion objective
- FlowMatchingObjective: Implements flow matching objective
"""
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
class BaseObjective(ABC):
"""
Base class for objectives used in Multi-Task DiT policy.
Defines the interface for training loss computation and conditional sampling.
"""
def __init__(self, config, action_dim: int, horizon: int):
self.config = config
self.action_dim = action_dim
self.horizon = horizon
@abstractmethod
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
"""Compute training loss for the objective.
Args:
model: The prediction network
batch: Training batch with observations and actions
conditioning_vec: Encoded observation features for conditioning
Returns:
Scalar loss tensor
"""
pass
@abstractmethod
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
"""Generate action samples conditioned on embedded observation features.
Args:
model: The prediction network
batch_size: The number of samples to generate
conditioning_vec: Encoded observation features for conditioning
Returns:
Generated action sequences (batch_size, horizon, action_dim)
"""
pass
class DiffusionObjective(BaseObjective):
"""Standard diffusion (DDPM/DDIM) objective implementation.
Contains the noise scheduler, training loss, and conditional sampling.
"""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
self.do_mask_loss_for_padding = do_mask_loss_for_padding
# Build noise scheduler
scheduler_kwargs = {
"num_train_timesteps": config.num_train_timesteps,
"beta_start": config.beta_start,
"beta_end": config.beta_end,
"beta_schedule": config.beta_schedule,
"clip_sample": getattr(config, "clip_sample", True),
"clip_sample_range": getattr(config, "clip_sample_range", 1.0),
"prediction_type": config.prediction_type,
}
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
# Inference steps default to training steps if not provided
self.num_inference_steps = (
config.num_inference_steps
if getattr(config, "num_inference_steps", None) is not None
else self.noise_scheduler.config.num_train_timesteps
)
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
clean_actions = batch["action"]
noise = torch.randn_like(clean_actions)
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(clean_actions.shape[0],),
device=clean_actions.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
# Target depends on prediction type
prediction_type = self.noise_scheduler.config.prediction_type
if prediction_type == "epsilon":
target = noise
elif prediction_type == "sample":
target = clean_actions
else:
raise ValueError(f"Unsupported prediction type: {prediction_type}")
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"] # (B, T)
loss = loss * valid_actions.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
model_output = model(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
conditioning_vec=conditioning_vec,
)
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
return sample
class FlowMatchingObjective(BaseObjective):
"""
Flow matching objective: trains a model to predict velocity fields v_θ(x, t) that transport
noise to data. In effect, this defines an interpolation path between noise and the data
distribution with a fixed target velocity.
"""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
"""Sample timesteps according to configured strategy.
Uniform: Sample t uniformly from [0,1]
Beta: Sample t from Beta(α,β) scaled to [0,s], emphasizing high noise (low t)
"""
if self.config.timestep_sampling.strategy_name == "uniform":
return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling.strategy_name == "beta":
# Sample u ~ Beta(α, β) then transform: t = s(1-u)
# This emphasizes t near 0 (high noise) when α > β
beta_dist = torch.distributions.Beta(
self.config.timestep_sampling.alpha, self.config.timestep_sampling.beta
)
u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling.s * (1.0 - u)
else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling.strategy_name}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
"""Compute flow matching training loss.
Trains the model to predict the velocity field along linear interpolation paths.
"""
data = batch["action"] # Clean action sequences (B, T, D)
batch_size = data.shape[0]
device = data.device
noise = torch.randn_like(data)
t = self._sample_timesteps(batch_size, device)
t_expanded = t.view(-1, 1, 1) # (B, 1, 1) for broadcasting
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
# The velocity we want the model to learn: v = data - (1-σ)·noise
target_velocity = data - (1 - self.config.sigma_min) * noise
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
# Optionally mask padded actions
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"] # (B, T)
loss = loss * valid_mask.unsqueeze(-1) # (B, T, D)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
"""Generate actions by integrating the learned velocity field via ODE.
Solves: dx/dt = v_θ(x,t) from t=0 (noise) to t=1 (data)
"""
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
# Start from random noise at t=0
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
# Time grid from 0 to 1
num_steps = self.config.num_integration_steps
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
# Integrate ODE using chosen method
if self.config.integration_method == "euler":
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
elif self.config.integration_method == "rk4":
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
else:
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
return x
def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
"""
Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)
"""
x = x_init
for i in range(len(time_grid) - 1):
t_scalar = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
# Create time tensor for batch
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
# Get velocity at current point
with torch.no_grad():
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
# Euler step
x = x + dt * velocity
return x
def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
"""4th-order Runge-Kutta integration.
Uses 4 velocity evaluations per step:
k1 = v(x, t)
k2 = v(x + dt·k1/2, t + dt/2)
k3 = v(x + dt·k2/2, t + dt/2)
k4 = v(x + dt·k3, t + dt)
x_next = x + dt/6·(k1 + 2k2 + 2k3 + k4)
4x slower than Euler but more accurate.
Note: in practice, this does not seem to matter much.
"""
x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
"""dynamics helper to get velocity at (x, t)"""
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
for i in range(len(time_grid) - 1):
t = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
# RK4 stages
k1 = dynamics(x, t)
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
k4 = dynamics(x + dt * k3, t + dt)
# Weighted combination
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
return x
@@ -0,0 +1,396 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Observation encoding for Multi-Task DiT policy.
Handles vision encoding, text encoding, robot state, and environment state.
"""
from abc import ABC, abstractmethod
import einops
import timm
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from transformers import CLIPTextModel, CLIPTokenizer
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class BaseVisionEncoder(ABC):
"""Abstract base class for vision encoders."""
@abstractmethod
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to feature maps."""
pass
@abstractmethod
def get_output_shape(self) -> tuple:
"""Get the output shape (C', H', W')."""
pass
class DinoV3Encoder(nn.Module, BaseVisionEncoder):
"""DinoV3 vision encoder using the CLS token for global image representation."""
def __init__(self, config):
super().__init__()
self.config = config
self.model_name = config.backbone
# Create the timm model
self.model = timm.create_model(
self.model_name,
pretrained=True,
num_classes=0,
)
self.num_non_spatial_tokens = 5 # 1 CLS + 4 register
self.embed_dim = self.model.embed_dim
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to feature maps."""
# Extract all features
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
# Use only the CLS token (first token)
cls_token = features[:, 0] # (B, embed_dim)
b, embed_dim = cls_token.shape
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
cls_features = cls_token.reshape(b, embed_dim, 1, 1)
return cls_features
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPEncoder(nn.Module, BaseVisionEncoder):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, config):
super().__init__()
self.config = config
self.model_name = config.backbone
# Create the timm model
self.model = timm.create_model(
self.model_name,
pretrained=True,
num_classes=0, # Remove classification head, we want features
)
# CLIP models have 1 CLS token (no register tokens like DinoV3)
self.num_non_spatial_tokens = 1
# Get embed_dim from model config
self.embed_dim = self.model.embed_dim
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token.
Preprocessing (resize, crop) is handled by ObservationEncoder
"""
# Extract all features
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
# Use only the CLS token (first token)
cls_token = features[:, 0] # (B, embed_dim)
b, embed_dim = cls_token.shape
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
cls_features = cls_token.reshape(b, embed_dim, 1, 1)
return cls_features
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
def create_vision_encoder(config) -> BaseVisionEncoder:
backbone_name = config.backbone.lower()
# Check if it's a CLIP model
if "clip" in backbone_name:
return CLIPEncoder(config)
# Check if it's a DinoV3 model
elif "dinov3" in backbone_name:
return DinoV3Encoder(config)
else:
raise ValueError(
f"Unsupported vision backbone: {config.backbone}. "
f"Currently supported: DinoV3 models and CLIP models"
)
# Registry for easy extension
VISION_ENCODER_REGISTRY: dict[str, type] = {
"dinov3": DinoV3Encoder,
"clip": CLIPEncoder,
}
def register_vision_encoder(name: str, encoder_class: type):
"""Register a new vision encoder type.
Args:
name: Identifier for the encoder type
encoder_class: Class implementing BaseVisionEncoder interface
"""
VISION_ENCODER_REGISTRY[name] = encoder_class
def get_registered_encoders() -> dict[str, type]:
"""Get all registered vision encoder types.
Returns:
Dictionary mapping encoder names to classes
"""
return VISION_ENCODER_REGISTRY.copy()
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and learnable projection."""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
self.model_name = model_name
self.projection_dim = projection_dim
# Load CLIP text encoder and tokenizer
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
# Freeze all CLIP text encoder parameters
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_embed_dim = self.text_encoder.config.hidden_size
# Learnable projection layer (always present, only trainable component)
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, text: str | list[str]) -> Tensor:
"""Encode text to feature vectors.
Args:
text: Single string or list of strings
Returns:
Text features of shape (B, projection_dim)
"""
# handle single string input
if isinstance(text, str):
text = [text]
text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()}
# encode text through CLIP (frozen)
with torch.no_grad():
outputs = self.text_encoder(**text_inputs)
# Extract pooled output (EOS token embedding)
clip_features = outputs.pooler_output # (B, text_embed_dim)
# project to desired dimension (trainable)
projected_features = self.projection(clip_features) # (B, projection_dim)
return projected_features
class ObservationEncoder(nn.Module):
"""Handles all observation processing for the conditioning vector."""
def __init__(self, config):
super().__init__()
self.config = config
vision_config = config.observation_encoder.vision
self._setup_preprocessing(vision_config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys()) # Preserve ordering
if vision_config.use_separate_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[create_vision_encoder(vision_config) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = create_vision_encoder(vision_config)
self.vision_encoders = None
else:
self.vision_encoder = None
self.vision_encoders = None
self.camera_names = []
self.num_cameras = 0
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
self.robot_state_dim = config.robot_state_feature.shape[0]
else:
self.robot_state_dim = 0
if hasattr(config, "env_state_feature") and config.env_state_feature:
self.env_state_dim = config.env_state_feature.shape[0]
else:
self.env_state_dim = 0
text_config = config.observation_encoder.text
self.text_dim = config.transformer.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
self._setup_vector_output()
def _apply_preprocessing(self, images: Tensor) -> Tensor:
"""Apply preprocessing transforms to images."""
if self.do_resize:
images = self.resize(images)
if self.do_crop:
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
return images
def _setup_preprocessing(self, vision_config):
"""Setup image preprocessing transforms."""
if vision_config.resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=vision_config.resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if vision_config.crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
if vision_config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
def _setup_vector_output(self):
"""Setup for vector output."""
total_dim = 0
# Vision features - get CLS token feature dimension
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or self.vision_encoders[0]
# Get output shape from encoder (deterministic for CLS tokens)
feature_map_shape = encoder_to_check.get_output_shape()
c, h, w = feature_map_shape
spatial_feature_dim = c * h * w # For CLS token: embed_dim * 1 * 1 = embed_dim
total_dim += spatial_feature_dim * self.num_cameras
# State features
total_dim += self.robot_state_dim
total_dim += self.env_state_dim
# Text features
total_dim += self.text_dim
# Account for temporal stacking
self.conditioning_dim = total_dim * self.config.n_obs_steps
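# Worked example (illustrative): with the default CLIP ViT-B/16 vision encoder (768-dim CLS
# token), two cameras, a 10-dim robot state, the 512-dim text projection, and n_obs_steps=2,
# conditioning_dim = (768 * 2 + 10 + 512) * 2 = 4116.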
def encode(self, batch: dict) -> Tensor:
"""Encode observations to vector format."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
conditioning_feats = []
conditioning_feats.append(batch[OBS_STATE])
if self.vision_encoder is not None or self.vision_encoders is not None:
images = batch[OBS_IMAGES] # (B, n_obs_steps, num_cameras, C, H, W)
# Handle case when n_obs=1 and time dimension might be squeezed
if len(images.shape) == 5:
# Shape is (B, N, C, H, W) - add time dimension
images = images.unsqueeze(1) # (B, 1, N, C, H, W)
vision_config = self.config.observation_encoder.vision
if vision_config.use_separate_encoder_per_camera:
# Process each camera with its own encoder
camera_features = []
for cam_idx in range(self.num_cameras):
# Extract images for this camera: (B, n_obs_steps, C, H, W)
cam_images = images[:, :, cam_idx]
# Rearrange to: (B*n_obs_steps, C, H, W)
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
# Apply preprocessing
cam_images_flat = self._apply_preprocessing(cam_images_flat)
# Process with camera-specific encoder (direct index access)
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
# Apply spatial vectorization (flatten CLS token features)
cam_visual_features = cam_features.flatten(start_dim=1)
# Reshape back: (B*n_obs_steps, feature_dim) → (B, n_obs_steps, feature_dim)
cam_features_reshaped = einops.rearrange(
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
)
camera_features.append(cam_features_reshaped)
# Concatenate features from all cameras: (B, n_obs_steps, total_feature_dim)
img_features = torch.cat(camera_features, dim=-1)
conditioning_feats.append(img_features)
else:
# Shared encoder for all cameras
# Rearrange to: (B*n_obs_steps*num_cameras, C, H, W)
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
images_flat = self._apply_preprocessing(images_flat)
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
# Reshape back and concatenate camera features
# (B*n_obs_steps*num_cameras, feature_dim) → (B, n_obs_steps, num_cameras*feature_dim)
img_features = einops.rearrange(
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
)
conditioning_feats.append(img_features)
if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
conditioning_feats.append(batch[OBS_ENV_STATE])
if self.text_encoder is not None and "task" in batch:
text_features = self.text_encoder(batch["task"]) # (B, text_dim)
# Expand across temporal dimension to match other features
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) # (B, T, text_dim)
print("Text features shape after unsqueeze and expand:", text_features.shape)
conditioning_feats.append(text_features)
for vec in conditioning_feats:
print(f"Conditioning feature shape: {vec.shape}")
combined_features = torch.cat(conditioning_feats, dim=-1) # (B, n_obs_steps, total_feature_dim)
return combined_features.flatten(start_dim=1) # (B, n_obs_steps * total_feature_dim)
@@ -0,0 +1,247 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer backbone for noise prediction in Multi-Task DiT policy.
Adapted from DiT (Diffusion Transformer: https://github.com/facebookresearch/DiT) for 1D trajectory data.
"""
import math
import torch
from torch import Tensor, nn
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
"""Modulate input with shift and scale for AdaLN-Zero.
Args:
x: Input tensor
shift: Shift parameter
scale: Scale parameter
Returns:
Modulated tensor: x * (1 + scale) + shift
"""
return x * (1 + scale) + shift
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embeddings for timesteps.
Identical to the reference implementation - generates smooth embeddings
for diffusion timestep values.
"""
def __init__(self, dim: int):
"""
Args:
dim: Embedding dimension
"""
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B,) tensor of timestep values
Returns:
(B, dim) positional embeddings
"""
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TransformerBlock(nn.Module):
"""DiT-style transformer block with AdaLN-Zero.
Official DiT implementation with 6-parameter adaptive layer normalization:
- shift_msa, scale_msa, gate_msa: for attention block
- shift_mlp, scale_mlp, gate_mlp: for MLP block
Reference: https://github.com/facebookresearch/DiT
"""
def __init__(
self, hidden_size: int = 128, num_heads: int = 4, num_features: int = 128, dropout: float = 0.0
):
"""
Args:
hidden_size: Hidden dimension of transformer
num_heads: Number of attention heads
num_features: Size of conditioning features
dropout: Dropout rate
"""
super().__init__()
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
# Layer normalizations (no learnable affine parameters, all adaptation via conditioning)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# Feed-forward network (MLP)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size * 4, hidden_size),
)
# AdaLN-Zero modulation: produces 6 parameters (shift, scale, gate for attn and mlp)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
def forward(self, x: Tensor, features: Tensor) -> Tensor:
"""
Args:
x: (B, T, hidden_size) input sequence
features: (B, num_features) conditioning features
Returns:
(B, T, hidden_size) processed sequence
"""
# Generate 6 modulation parameters from conditioning
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
features
).chunk(6, dim=1)
# Attention block: norm → modulate → attn → gate × output → residual
# modulate requires unsqueeze(1) to add sequence dimension for broadcasting
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
x = x + gate_msa.unsqueeze(1) * attn_out
# MLP block: norm → modulate → mlp → gate × output → residual
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
mlp_out = self.mlp(mlp_input)
x = x + gate_mlp.unsqueeze(1) * mlp_out
return x
class DiffusionTransformer(nn.Module):
"""
Transformer-based diffusion noise prediction model.
"""
def __init__(self, config, conditioning_dim: int):
"""Initialize transformer for noise prediction.
Args:
config: Multi-Task DiTConfig with transformer parameters
conditioning_dim: Dimension of concatenated observation features
"""
super().__init__()
self.config = config
self.transformer_config = config.transformer
self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon
self.hidden_size = self.transformer_config.hidden_dim
self.num_layers = self.transformer_config.num_layers
self.num_heads = self.transformer_config.num_heads
self.dropout = self.transformer_config.dropout
self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
nn.GELU(),
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
nn.GELU(),
)
self.cond_dim = self.timestep_embed_dim + conditioning_dim
# Project action dimensions to hidden size
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if self.transformer_config.use_positional_encoding:
# Learnable positional embeddings for sequence positions
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
)
else:
self.pos_embedding = None
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
num_features=self.cond_dim,
dropout=self.dropout,
)
for _ in range(self.num_layers)
]
)
# Project back to action dimensions
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
# Zero-initialize adaLN_modulation layers for AdaLN-Zero
self._initialize_weights()
def _initialize_weights(self):
"""
Zero-initializing the final linear layer of adaLN_modulation in each block improves training stability
"""
for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
"""Predict noise to remove from noisy actions.
Args:
x: (B, T, action_dim) noisy action sequences
timestep: (B,) diffusion timesteps
conditioning_vec: (B, conditioning_dim) observation features (required)
Returns:
(B, T, action_dim) predicted noise
"""
_, seq_len, _ = x.shape
timestep_features = self.time_mlp(timestep) # (B, timestep_embed_dim)
# conditioning_vec is now required
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1) # (B, cond_dim)
# Project action sequence to hidden dimension
hidden_seq = self.input_proj(x) # (B, T, hidden_size)
if self.pos_embedding is not None:
# Add learned positional embeddings
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :] # (B, T, hidden_size)
# Pass through transformer layers with conditioning
for block in self.transformer_blocks:
hidden_seq = block(hidden_seq, cond_features) # (B, T, hidden_size)
# Project back to action space
output = self.output_proj(hidden_seq) # (B, T, action_dim)
return output
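# Shape sketch (illustrative, using the default config values): for a batch of 8 with
# horizon=100, action_dim=10, hidden_dim=512, and diffusion_step_embed_dim=256, the forward pass
# maps x (8, 100, 10) -> input_proj (8, 100, 512) -> transformer blocks (8, 100, 512)
# -> output_proj (8, 100, 10), with cond_features of shape (8, 256 + conditioning_dim).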
@@ -0,0 +1,75 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_multi_task_dit_pre_post_processors(
config: MultiTaskDiTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Creates pre- and post-processing pipelines for the Multi-Task DiT policy."""
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
device=config.device,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
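A hypothetical usage sketch of this factory (not part of the committed module): it reuses the definitions above, builds a minimal state-only config with arbitrary shapes, and leaves dataset_stats as None, whereas in practice the normalization stats would come from the training dataset.
# Hypothetical wiring sketch (illustrative only); relies on the module's definitions above.
# Feature shapes are arbitrary and default hyperparameters are assumed to be acceptable.
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_STATE

config = MultiTaskDiTConfig(
    input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
    output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(10,))},
)
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config, dataset_stats=None)
# preprocessor: renames keys, adds a batch dim, moves tensors to config.device, normalizes features
# postprocessor: unnormalizes the predicted action and moves it back to the CPU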
@@ -0,0 +1,377 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script for Multi-Task DiT policy.
To run tests with GPU on Modal (temporary script):
modal run run_tests_modal.py
To run tests locally:
python -m pytest tests/policies/test_multi_task_dit_policy.py -v
"""
import pytest
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
DiffusionConfig,
FlowMatchingConfig,
MultiTaskDiTConfig,
)
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
@pytest.fixture(autouse=True)
def set_random_seed():
seed = 17
set_seed(seed)
def create_train_batch(
batch_size: int = 2,
n_obs_steps: int = 2,
horizon: int = 16,
state_dim: int = 10,
action_dim: int = 10,
height: int = 224,
width: int = 224,
) -> dict[str, Tensor]:
"""Create a training batch with visual input and text."""
return {
"observation.state": torch.randn(batch_size, n_obs_steps, state_dim),
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width),
ACTION: torch.randn(batch_size, horizon, action_dim),
"task": ["pick up the cube"] * batch_size,
}
def create_observation_batch(
batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224
) -> dict:
"""Create observation batch for inference for a single timestep."""
return {
"observation.state": torch.randn(batch_size, state_dim),
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width),
"task": ["pick up the red cube"] * batch_size,
}
def create_config(
state_dim: int = 10,
action_dim: int = 10,
n_obs_steps: int = 2,
horizon: int = 16,
n_action_steps: int = 8,
with_visual: bool = True,
height: int = 224,
width: int = 224,
) -> MultiTaskDiTConfig:
"""Create a MultiTaskDiT config for testing.
Args:
state_dim: Dimension of state observations
action_dim: Dimension of actions
n_obs_steps: Number of observation steps
horizon: Action prediction horizon
n_action_steps: Number of action steps to execute
with_visual: Whether to include visual input (default: True)
height: Image height (only used if with_visual=True)
width: Image width (only used if with_visual=True)
"""
input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}
if with_visual:
input_features[f"{OBS_IMAGES}.laptop"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(3, height, width)
)
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
# Use smaller model for faster tests
config.transformer.hidden_dim = 128
config.transformer.num_layers = 2
config.transformer.num_heads = 4
config.validate_features()
return config
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int):
"""Test forward pass (training mode)."""
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
policy = MultiTaskDiTPolicy(config=config)
policy.train()
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Test forward pass
loss, _ = policy.forward(batch)
assert loss is not None
assert torch.isfinite(loss)
assert loss.shape == ()
# Test backward pass
loss.backward()
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
"""Test select_action (inference mode)."""
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
policy = MultiTaskDiTPolicy(config=config)
policy.eval()
policy.reset() # Reset queues before inference
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
def test_multi_task_dit_policy_diffusion_objective():
"""Test policy with diffusion objective."""
batch_size = 2
state_dim = 10
action_dim = 10
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
config.objective = DiffusionConfig(
noise_scheduler_type="DDPM",
num_train_timesteps=100,
num_inference_steps=10,
)
policy = MultiTaskDiTPolicy(config=config)
policy.train()
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Test forward pass
loss, _ = policy.forward(batch)
assert loss is not None
assert torch.isfinite(loss)
# Test inference
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
def test_multi_task_dit_policy_flow_matching_objective():
"""Test policy with flow matching objective."""
batch_size = 2
state_dim = 10
action_dim = 10
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
config.objective = FlowMatchingConfig(
sigma_min=0.0,
num_integration_steps=10, # Use fewer steps for faster tests
integration_method="euler",
)
policy = MultiTaskDiTPolicy(config=config)
policy.train()
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Test forward pass
loss, _ = policy.forward(batch)
assert loss is not None
assert torch.isfinite(loss)
# Test inference
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
def test_multi_task_dit_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded correctly."""
root = tmp_path / "test_multi_task_dit_save_and_load"
state_dim = 10
action_dim = 10
batch_size = 2
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
policy = MultiTaskDiTPolicy(config=config)
policy.eval()
# Get device before saving
device = next(policy.parameters()).device
policy.save_pretrained(root)
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
# Explicitly move loaded_policy to the same device
loaded_policy.to(device)
loaded_policy.eval()
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Move batch to the same device as the policy
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device)
with torch.no_grad():
with seeded_context(12):
# Collect policy values before saving
loss, _ = policy.forward(batch)
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Move observation batch to device
for key in observation_batch:
if isinstance(observation_batch[key], torch.Tensor):
observation_batch[key] = observation_batch[key].to(device)
actions = policy.select_action(observation_batch)
with seeded_context(12):
# Collect policy values after loading
loaded_loss, _ = loaded_policy.forward(batch)
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Move observation batch to device
for key in loaded_observation_batch:
if isinstance(loaded_observation_batch[key], torch.Tensor):
loaded_observation_batch[key] = loaded_observation_batch[key].to(device)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
# Compare state dicts
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
# Compare values before and after saving and loading
assert torch.allclose(loss, loaded_loss)
assert torch.allclose(actions, loaded_actions)
def test_multi_task_dit_policy_get_optim_params():
"""Test that the policy returns correct optimizer parameter groups."""
config = create_config(
state_dim=10,
action_dim=10,
n_obs_steps=2,
horizon=16,
n_action_steps=8,
)
policy = MultiTaskDiTPolicy(config=config)
param_groups = policy.get_optim_params()
# Should have 2 parameter groups: non-vision and vision encoder
assert len(param_groups) == 2
# First group is non-vision params (no lr specified, will use default)
assert "params" in param_groups[0]
assert len(param_groups[0]["params"]) > 0
# Second group is vision encoder params with different lr
assert "params" in param_groups[1]
assert "lr" in param_groups[1]
expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier
assert param_groups[1]["lr"] == expected_lr