Add multitask diffusion transformer policy
@@ -118,6 +118,7 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "timm>=1.0.20"]
groot = [
    "lerobot[transformers-dep]",
    "peft>=0.13.0,<1.0.0",

@@ -157,7 +157,7 @@ available_datasets = sorted(
)

# lists all available policies from `lerobot/policies`
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
available_policies = ["act", "multi_task_dit", "diffusion", "tdmpc", "vqbet"]

# lists all available robots from `lerobot/robots`
available_robots = [

@@ -15,6 +15,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig

@@ -25,6 +26,7 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
__all__ = [
    "ACTConfig",
    "DiffusionConfig",
    "MultiTaskDiTConfig",
    "PI0Config",
    "PI05Config",
    "SmolVLAConfig",

@@ -31,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy

@@ -59,7 +60,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:

    Args:
        name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
            "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
            "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".

    Returns:
        The policy class corresponding to the given name.

@@ -79,6 +80,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
        from lerobot.policies.act.modeling_act import ACTPolicy

        return ACTPolicy
    elif name == "multi_task_dit":
        from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy

        return MultiTaskDiTPolicy
    elif name == "vqbet":
        from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy

@@ -120,8 +125,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:

    Args:
        policy_type: The type of the policy. Supported types include "tdmpc",
            "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
            "reward_classifier".
            "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
            "reward_classifier", "smolvla".
        **kwargs: Keyword arguments to be passed to the configuration class constructor.

    Returns:

@@ -136,6 +141,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return DiffusionConfig(**kwargs)
    elif policy_type == "act":
        return ACTConfig(**kwargs)
    elif policy_type == "multi_task_dit":
        return MultiTaskDiTConfig(**kwargs)
    elif policy_type == "vqbet":
        return VQBeTConfig(**kwargs)
    elif policy_type == "pi0":

@@ -274,6 +281,16 @@ def make_pre_post_processors(
            dataset_stats=kwargs.get("dataset_stats"),
        )

    elif isinstance(policy_cfg, MultiTaskDiTConfig):
        from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
            make_multi_task_dit_pre_post_processors,
        )

        processors = make_multi_task_dit_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
        )

    elif isinstance(policy_cfg, VQBeTConfig):
        from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors

@@ -0,0 +1,45 @@
# Multi-Task DiT Policy

For details describing the architecture, see the citations below and the blog post from Bryson Jones: https://brysonkjones.substack.com/p/dissecting-multi-task-diffusion-transformer-policy

## Training and Inference Baseline Recommendations

### Training

- Number of demonstrations: >100 per task
- Batch Size: 320
- Objective: Diffusion
- Cameras: At least two, with one egocentric view per arm
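A minimal sketch of how these training recommendations map onto the new config (hypothetical values; it assumes the factory module lives at `lerobot.policies.factory` and that the `multi_task_dit` extra added to `pyproject.toml` in this commit is installed, e.g. `pip install -e ".[multi_task_dit]"`; the batch size of 320 is a training-loop argument, not a policy field):

```python
from lerobot.policies.factory import make_policy_config  # assumed module path

# Hypothetical starting point following the recommendations above; kwargs are
# forwarded to MultiTaskDiTConfig by make_policy_config.
cfg = make_policy_config(
    "multi_task_dit",
    n_obs_steps=2,      # temporal context (t-1, t)
    horizon=100,        # predicted action steps
    n_action_steps=24,  # actions executed per call (~0.8 s receding horizon)
)
```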
### Inference

- GPU: NVIDIA RTX 5070 Ti class or better
- Sampling:
  - Strategy: DDIM
  - Number of Timesteps: 10
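A minimal sketch of the recommended inference setup, using the objective config added in this commit (a policy is normally built from this config together with dataset features via LeRobot's policy factory, so the rollout calls are shown only as comments):

```python
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
    DiffusionConfig,
    MultiTaskDiTConfig,
)

# DDIM sampling with 10 denoising steps, per the recommendation above.
cfg = MultiTaskDiTConfig(
    objective=DiffusionConfig(noise_scheduler_type="DDIM", num_inference_steps=10),
)

# At rollout time, once a policy has been instantiated from `cfg`:
#   policy.reset()                      # clear observation/action queues on env.reset()
#   action = policy.select_action(obs)  # one action per call; chunks are cached internally
```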
## Citation

If you use this work, please cite the following:

```bibtex
@misc{jones2025multitaskditpolicy,
    author = {Bryson Jones},
    title = {Dissecting Multitask Diffusion Transformer Policy},
    year = {2025},
    url = {https://brysonkjones.substack.com/p/dissecting-multitask-diffusion-transformer-policy},
    note = {Blog post}
}
```

```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
    author = {TRI LBM Team},
    title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
    year = {2025},
    eprint = {2507.05331},
    archivePrefix = {arXiv},
    primaryClass = {cs.RO},
    url = {https://arxiv.org/abs/2507.05331}
}
```
@@ -0,0 +1,433 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamConfig
|
||||
from lerobot.optim.schedulers import DiffuserSchedulerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObjectiveConfig(draccus.ChoiceRegistry):
|
||||
"""Base configuration for model objectives (diffusion, flow matching, etc.)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ObjectiveConfig.register_subclass("diffusion")
|
||||
@dataclass
|
||||
class DiffusionConfig(ObjectiveConfig):
|
||||
"""Configuration for standard diffusion model training and inference.
|
||||
|
||||
These parameters control the noise scheduling and denoising process for
|
||||
standard DDPM/DDIM diffusion models.
|
||||
"""
|
||||
|
||||
objective_name: str = field(default="diffusion", init=False)
|
||||
|
||||
# Noise scheduler configuration - controls diffusion process
|
||||
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
|
||||
num_train_timesteps: int = 100 # 100 noise levels for fine-grained control
|
||||
beta_schedule: str = "squaredcos_cap_v2" # Cosine schedule prevents extreme noise
|
||||
beta_start: float = 0.0001 # Small initial noise level
|
||||
beta_end: float = 0.02 # Moderate final noise level
|
||||
prediction_type: str = "epsilon" # Predict noise (works better than direct prediction)
|
||||
clip_sample: bool = True # Prevent extreme action values
|
||||
clip_sample_range: float = 1.0 # Clip to [-1, 1] range
|
||||
|
||||
# Inference configuration
|
||||
num_inference_steps: int | None = None # Default to num_train_timesteps
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate diffusion-specific parameters."""
|
||||
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
|
||||
raise ValueError(
|
||||
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
|
||||
)
|
||||
|
||||
if self.prediction_type not in ["epsilon", "sample"]:
|
||||
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
|
||||
|
||||
if self.num_train_timesteps <= 0:
|
||||
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
|
||||
|
||||
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
|
||||
raise ValueError(
|
||||
"beta values must satisfy 0 <= beta_start <= beta_end <= 1, "
|
||||
f"got {self.beta_start}, {self.beta_end}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimestepSamplingConfig(draccus.ChoiceRegistry):
|
||||
"""Base configuration for timestep sampling strategies during training."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@TimestepSamplingConfig.register_subclass("uniform")
|
||||
@dataclass
|
||||
class UniformTimestepSamplingConfig(TimestepSamplingConfig):
|
||||
"""Uniform timestep sampling from [0, 1]."""
|
||||
|
||||
strategy_name: str = field(default="uniform", init=False)
|
||||
|
||||
|
||||
@TimestepSamplingConfig.register_subclass("beta")
|
||||
@dataclass
|
||||
class BetaTimestepSamplingConfig(TimestepSamplingConfig):
|
||||
"""Beta distribution timestep sampling.
|
||||
|
||||
Samples from Beta distribution emphasizing low timesteps (high noise).
|
||||
|
||||
    This was inspired by the Physical Intelligence PI0 model, where sampling
    timesteps from a beta distribution during training was found to improve
    sample quality.
|
||||
"""
|
||||
|
||||
strategy_name: str = field(default="beta", init=False)
|
||||
|
||||
s: float = 0.999 # Max timestep threshold for beta sampling
|
||||
alpha: float = 1.5 # Beta distribution alpha parameter
|
||||
beta: float = 1.0 # Beta distribution beta parameter
|
||||
|
||||
def __post_init__(self):
|
||||
if not (0.0 < self.s <= 1.0):
|
||||
raise ValueError(f"s must be in (0, 1], got {self.s}")
|
||||
|
||||
if self.alpha <= 0:
|
||||
raise ValueError(f"alpha must be positive, got {self.alpha}")
|
||||
|
||||
if self.beta <= 0:
|
||||
raise ValueError(f"beta must be positive, got {self.beta}")
|
||||
|
||||
|
||||
@ObjectiveConfig.register_subclass("flow_matching")
|
||||
@dataclass
|
||||
class FlowMatchingConfig(ObjectiveConfig):
|
||||
"""Configuration for flow matching training and inference.
|
||||
|
||||
These parameters control the velocity field learning and ODE integration
|
||||
process for flow matching models.
|
||||
"""
|
||||
|
||||
objective_name: str = field(default="flow_matching", init=False)
|
||||
|
||||
# Flow path construction
|
||||
sigma_min: float = 0.0 # Minimum noise level in flow interpolation path
|
||||
|
||||
# ODE integration for inference
|
||||
num_integration_steps: int = (
|
||||
100 # Number of ODE integration steps (increased from 50 for smoother trajectories)
|
||||
)
|
||||
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
|
||||
|
||||
# Timestep sampling strategy for training
|
||||
# Beta distribution found to be the most effective in practice, so it is the default
|
||||
timestep_sampling: TimestepSamplingConfig = field(default_factory=BetaTimestepSamplingConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
if not (0.0 <= self.sigma_min <= 1.0):
|
||||
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
|
||||
|
||||
if self.num_integration_steps <= 0:
|
||||
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
|
||||
|
||||
if self.integration_method not in ["euler", "rk4"]:
|
||||
raise ValueError(f"integration_method must be 'euler' or 'rk4', got {self.integration_method}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerConfig:
|
||||
"""Configuration for Transformer-based prediction model.
|
||||
|
||||
These parameters control the transformer architecture used for noise/velocity
|
||||
prediction in diffusion and flow matching models.
|
||||
"""
|
||||
|
||||
# Transformer architecture parameters
|
||||
hidden_dim: int = 512 # Hidden dimension of transformer
|
||||
num_layers: int = 6 # Number of transformer layers
|
||||
num_heads: int = 8 # Number of attention heads
|
||||
dropout: float = 0.1 # Dropout rate
|
||||
use_positional_encoding: bool = True # Whether to use positional encoding
|
||||
diffusion_step_embed_dim: int = 256 # Timestep embedding size
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate Transformer-specific parameters."""
|
||||
if self.hidden_dim <= 0:
|
||||
raise ValueError("hidden_dim must be positive")
|
||||
|
||||
if self.num_layers <= 0:
|
||||
raise ValueError("num_layers must be positive")
|
||||
|
||||
if self.num_heads <= 0:
|
||||
raise ValueError("num_heads must be positive")
|
||||
|
||||
if self.hidden_dim % self.num_heads != 0:
|
||||
raise ValueError("hidden_dim must be divisible by num_heads")
|
||||
|
||||
if not (0.0 <= self.dropout <= 1.0):
|
||||
raise ValueError("dropout must be between 0.0 and 1.0")
|
||||
|
||||
if self.diffusion_step_embed_dim <= 0:
|
||||
raise ValueError("diffusion_step_embed_dim must be positive")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionEncoderConfig(draccus.ChoiceRegistry):
|
||||
"""Base configuration for vision encoders.
|
||||
|
||||
All image preprocessing is centralized here:
|
||||
1. Resize (optional) - resize images to target resolution
|
||||
2. Crop (optional) - crop after resize, must be smaller than resize_shape
|
||||
3. Random crop - whether to use random cropping during training
|
||||
"""
|
||||
|
||||
    # Common parameters across all vision encoders
    use_separate_encoder_per_camera: bool = False
|
||||
|
||||
# Learning rate multiplier for vision encoder parameters
|
||||
# Vision encoder learning rate = optimizer_lr * lr_multiplier
|
||||
lr_multiplier: float = 0.1
|
||||
|
||||
# Image preprocessing (centralized)
|
||||
resize_shape: tuple[int, int] | None = None
|
||||
crop_shape: tuple[int, int] | None = (224, 224) # default input size for CLIP
|
||||
crop_is_random: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if (
|
||||
self.resize_shape
|
||||
and self.crop_shape
|
||||
and (self.crop_shape[0] > self.resize_shape[0] or self.crop_shape[1] > self.resize_shape[1])
|
||||
):
|
||||
raise ValueError(
|
||||
f"crop_shape {self.crop_shape} must be smaller than or equal to "
|
||||
f"resize_shape {self.resize_shape}. Got crop={self.crop_shape}, resize={self.resize_shape}"
|
||||
)
|
||||
|
||||
|
||||
@VisionEncoderConfig.register_subclass("dinov3")
|
||||
@dataclass
|
||||
class DinoV3EncoderConfig(VisionEncoderConfig):
|
||||
"""DinoV3 vision encoder configuration.
|
||||
|
||||
DinoV3 is a self-supervised Vision Transformer trained by Meta.
|
||||
CLS token usage and spatial feature extraction are handled automatically.
|
||||
|
||||
Available backbones:
|
||||
- vit_base_patch16_dinov3.lvd1689m (768 dims)
|
||||
"""
|
||||
|
||||
backbone: str = "vit_base_patch16_dinov3.lvd1689m"
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Validate backbone name
|
||||
valid_backbones = [
|
||||
"vit_base_patch16_dinov3.lvd1689m",
|
||||
]
|
||||
if self.backbone not in valid_backbones:
|
||||
raise ValueError(f"backbone must be one of {valid_backbones}, got '{self.backbone}'")
|
||||
|
||||
|
||||
@VisionEncoderConfig.register_subclass("clip")
|
||||
@dataclass
|
||||
class CLIPVisionEncoderConfig(VisionEncoderConfig):
|
||||
"""CLIP vision encoder configuration.
|
||||
|
||||
CLIP is a vision-language model trained by OpenAI.
|
||||
CLS token usage is handled automatically.
|
||||
CLIP's internal preprocessing (resize to 224x224) can be overridden
|
||||
by setting resize_shape and crop_shape.
|
||||
|
||||
Available backbones:
|
||||
- vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224)
|
||||
"""
|
||||
|
||||
backbone: str = "vit_base_patch16_clip_224.openai"
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Validate backbone name
|
||||
if "clip" not in self.backbone.lower():
|
||||
raise ValueError(f"backbone must be a CLIP model, got '{self.backbone}'")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextEncoderConfig(draccus.ChoiceRegistry):
|
||||
"""Base configuration for text encoders.
|
||||
|
||||
If a text encoder is set in ObservationEncoderConfig, text conditioning
|
||||
is automatically enabled.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@TextEncoderConfig.register_subclass("clip")
|
||||
@dataclass
|
||||
class CLIPTextEncoderConfig(TextEncoderConfig):
|
||||
"""CLIP text encoder for task conditioning.
|
||||
|
||||
Uses CLIP's text encoder to embed task descriptions, which are then
|
||||
used to condition the policy. The text embeddings are processed by
|
||||
a learnable projection layer before being concatenated into the
|
||||
conditioning vector.
|
||||
"""
|
||||
|
||||
model: str = "openai/clip-vit-base-patch16"
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if "clip" not in self.model.lower():
|
||||
raise ValueError(f"CLIP text encoder requires a CLIP model. Got '{self.model}'")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationEncoderConfig:
|
||||
"""Top-level configuration for observation encoding.
|
||||
|
||||
This config combines:
|
||||
- Vision encoding (required): DinoV3 or CLIP vision encoder
|
||||
"""
|
||||
|
||||
vision: VisionEncoderConfig = field(default_factory=CLIPVisionEncoderConfig)
|
||||
text: TextEncoderConfig = field(default_factory=CLIPTextEncoderConfig)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("multi_task_dit")
|
||||
@dataclass
|
||||
class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
"""
|
||||
Configuration class for the Multi-Task Diffusion Transformer (DiT) policy.
|
||||
"""
|
||||
|
||||
# Temporal structure - controls how the policy processes time and predicts actions
|
||||
n_obs_steps: int = 2 # num observations for temporal context (..., t-1, t)
|
||||
horizon: int = 100 # predicted action steps into the future
|
||||
    n_action_steps: int = 24  # actions executed per policy call (receding horizon); ~0.8 s (24 steps at 30 Hz) is a good place to start
|
||||
|
||||
# Normalization strategy - critical for diffusion model performance
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD, # Standard ImageNet normalization for vision
|
||||
"STATE": NormalizationMode.MIN_MAX, # [-1,1] range for proper diffusion clipping
|
||||
"ACTION": NormalizationMode.MIN_MAX, # [-1,1] range required for diffusion process
|
||||
}
|
||||
)
|
||||
|
||||
drop_n_last_frames: int | None = None # Auto-calculated: horizon - n_action_steps - n_obs_steps + 1
|
||||
observation_encoder: ObservationEncoderConfig = field(default_factory=ObservationEncoderConfig)
|
||||
transformer: TransformerConfig = field(default_factory=TransformerConfig)
|
||||
objective: ObjectiveConfig = field(default_factory=DiffusionConfig)
|
||||
do_mask_loss_for_padding: bool = False # same logic as is implemented in LeRobot DP implementation
|
||||
|
||||
# training optimizer and scheduler hyperparameters
|
||||
optimizer_lr: float = 2e-5
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.0 # No weight decay is suggested to be optimal
|
||||
scheduler_name: str = "cosine"
|
||||
scheduler_warmup_steps: int = 0 # No warmup found to be optimal
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.drop_n_last_frames is None:
|
||||
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
"""Return Adam optimizer configuration optimized for diffusion training.
|
||||
|
||||
Note: Vision encoder learning rate is set separately via get_optim_params.
|
||||
"""
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
"""Return learning rate scheduler configuration."""
|
||||
return DiffuserSchedulerConfig(
|
||||
name=self.scheduler_name,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate that required input features are present and properly configured."""
|
||||
# Robot state is always present via self.robot_state_feature, so we don't need to enforce images/env_state
|
||||
# This allows for testing and simple state-only policies
|
||||
|
||||
# Validate crop shape fits within image dimensions
|
||||
crop_shape = self.observation_encoder.vision.crop_shape
|
||||
if crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if crop_shape[0] > image_ft.shape[1] or crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
)
|
||||
|
||||
# Ensure all images have same shape (current limitation)
|
||||
if len(self.image_features) > 0:
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def model_objective(self) -> str:
|
||||
return self.objective.objective_name
|
||||
|
||||
@property
|
||||
def is_diffusion(self) -> bool:
|
||||
return isinstance(self.objective, DiffusionConfig)
|
||||
|
||||
@property
|
||||
def is_flow_matching(self) -> bool:
|
||||
return isinstance(self.objective, FlowMatchingConfig)
|
||||
|
||||
def get_objective_config(self) -> DiffusionConfig | FlowMatchingConfig:
|
||||
"""Get the objective-specific configuration with proper typing."""
|
||||
return self.objective
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
"""Delta indices for stacking observations. Provides temporal context."""
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
"""Delta indices for action horizon prediction."""
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
"""Indices for reward deltas (not used in diffusion policy)."""
|
||||
return None
|
||||
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Multi-Task Diffusion Transformer (DiT) Policy
|
||||
|
||||
Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
|
||||
Supports both diffusion and flow matching objectives for action generation.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.policies.multi_task_dit.modules.objectives import DiffusionObjective, FlowMatchingObjective
|
||||
from lerobot.policies.multi_task_dit.modules.observation_encoder import ObservationEncoder
|
||||
from lerobot.policies.multi_task_dit.modules.transformer import DiffusionTransformer
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
|
||||
class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
config_class = MultiTaskDiTConfig
|
||||
name = "multi_task_dit"
|
||||
|
||||
def __init__(self, config: MultiTaskDiTConfig):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self._queues = None
|
||||
|
||||
self.observation_encoder = ObservationEncoder(config)
|
||||
conditioning_dim = self.observation_encoder.conditioning_dim
|
||||
self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)
|
||||
|
||||
action_dim = config.action_feature.shape[0]
|
||||
horizon = config.horizon
|
||||
|
||||
self.model_objective = config.model_objective
|
||||
if config.is_diffusion:
|
||||
self.objective = DiffusionObjective(
|
||||
config.get_objective_config(),
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
||||
)
|
||||
elif config.is_flow_matching:
|
||||
self.objective = FlowMatchingObjective(
|
||||
config.get_objective_config(),
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_objective: {self.model_objective}")
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> list:
|
||||
"""Returns parameter groups with different learning rates for vision vs non-vision parameters."""
|
||||
non_vision_params = []
|
||||
vision_encoder_params = []
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
|
||||
if "observation_encoder.vision_encoder" in name:
|
||||
vision_encoder_params.append(param)
|
||||
else:
|
||||
non_vision_params.append(param)
|
||||
|
||||
return [
|
||||
{"params": non_vision_params},
|
||||
{
|
||||
"params": vision_encoder_params,
|
||||
"lr": self.config.optimizer_lr * self.config.observation_encoder.vision.lr_multiplier,
|
||||
},
|
||||
]
|
||||
|
||||
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
conditioning_vec = self.observation_encoder.encode(batch)
|
||||
actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)
|
||||
|
||||
start_idx = n_obs_steps - 1
|
||||
end_idx = start_idx + self.config.n_action_steps
|
||||
return actions[:, start_idx:end_idx]
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`."""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
if self.config.observation_encoder.text:
|
||||
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
|
||||
n_obs_steps = batch["observation.state"].shape[1]
|
||||
horizon = batch["action"].shape[1]
|
||||
assert horizon == self.config.horizon
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
conditioning_vec = self.observation_encoder.encode(batch)
|
||||
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
|
||||
|
||||
return loss, None
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
original_batch_keys = set(batch.keys())
|
||||
new_batch = {}
|
||||
for k in self._queues:
|
||||
if k in original_batch_keys:
|
||||
if self._queues[k] and isinstance(self._queues[k][-1][0], str):
|
||||
# for task description which is a list of strings
|
||||
new_batch[k] = self._queues[k][-1]
|
||||
else:
|
||||
queue_values = list(self._queues[k])
|
||||
new_batch[k] = torch.stack(queue_values, dim=1)
|
||||
batch = new_batch
|
||||
|
||||
actions = self._generate_actions(batch)
|
||||
return actions
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method manages caching of observations and actions by generating an action chunk
|
||||
and returning actions from the cache until it's depleted.
|
||||
"""
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1))
|
||||
|
||||
return self._queues[ACTION].popleft()
|
||||
@@ -0,0 +1,305 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This module contains a base objective class, and implementation of the objective
|
||||
classes for use in the Multi-Task Diffusion Transformer Policy.
|
||||
|
||||
Architecture:
|
||||
- BaseObjective: Abstract interface definition
|
||||
- DiffusionObjective: Implements standard DDPM/DDIM diffusion objective
|
||||
- FlowMatchingObjective: Implements flow matching objective
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class BaseObjective(ABC):
|
||||
"""
|
||||
Base class for objectives used in Multi-Task DiT policy.
|
||||
Defines the interface for training loss computation and conditional sampling.
|
||||
"""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int):
|
||||
self.config = config
|
||||
self.action_dim = action_dim
|
||||
self.horizon = horizon
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
"""Compute training loss for the objective.
|
||||
|
||||
Args:
|
||||
model: The prediction network
|
||||
batch: Training batch with observations and actions
|
||||
conditioning_vec: Encoded observation features for conditioning
|
||||
|
||||
Returns:
|
||||
Scalar loss tensor
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||
"""Generate action samples conditioned on embedded observation features.
|
||||
|
||||
Args:
|
||||
model: The prediction network
|
||||
batch_size: The number of samples to generate
|
||||
conditioning_vec: Encoded observation features for conditioning
|
||||
|
||||
Returns:
|
||||
Generated action sequences (batch_size, horizon, action_dim)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DiffusionObjective(BaseObjective):
|
||||
"""Standard diffusion (DDPM/DDIM) objective implementation.
|
||||
|
||||
Contains the noise scheduler, training loss, and conditional sampling.
|
||||
"""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||
super().__init__(config, action_dim, horizon)
|
||||
self.do_mask_loss_for_padding = do_mask_loss_for_padding
|
||||
|
||||
# Build noise scheduler
|
||||
scheduler_kwargs = {
|
||||
"num_train_timesteps": config.num_train_timesteps,
|
||||
"beta_start": config.beta_start,
|
||||
"beta_end": config.beta_end,
|
||||
"beta_schedule": config.beta_schedule,
|
||||
"clip_sample": getattr(config, "clip_sample", True),
|
||||
"clip_sample_range": getattr(config, "clip_sample_range", 1.0),
|
||||
"prediction_type": config.prediction_type,
|
||||
}
|
||||
|
||||
if config.noise_scheduler_type == "DDPM":
|
||||
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
|
||||
elif config.noise_scheduler_type == "DDIM":
|
||||
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
|
||||
|
||||
# Inference steps default to training steps if not provided
|
||||
self.num_inference_steps = (
|
||||
config.num_inference_steps
|
||||
if getattr(config, "num_inference_steps", None) is not None
|
||||
else self.noise_scheduler.config.num_train_timesteps
|
||||
)
|
||||
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
clean_actions = batch["action"]
|
||||
noise = torch.randn_like(clean_actions)
|
||||
timesteps = torch.randint(
|
||||
low=0,
|
||||
high=self.noise_scheduler.config.num_train_timesteps,
|
||||
size=(clean_actions.shape[0],),
|
||||
device=clean_actions.device,
|
||||
).long()
|
||||
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
|
||||
|
||||
# Target depends on prediction type
|
||||
prediction_type = self.noise_scheduler.config.prediction_type
|
||||
if prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif prediction_type == "sample":
|
||||
target = clean_actions
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type: {prediction_type}")
|
||||
|
||||
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
|
||||
loss = F.mse_loss(predicted, target, reduction="none")
|
||||
|
||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||
valid_actions = ~batch["action_is_pad"] # (B, T)
|
||||
loss = loss * valid_actions.unsqueeze(-1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.horizon, self.action_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
||||
for t in self.noise_scheduler.timesteps:
|
||||
model_output = model(
|
||||
sample,
|
||||
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
||||
conditioning_vec=conditioning_vec,
|
||||
)
|
||||
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class FlowMatchingObjective(BaseObjective):
|
||||
"""
|
||||
    Flow matching objective: trains a model to predict velocity fields v_θ(x, t) that transport
    noise to data. The training path is a linear interpolation between noise and the data
    distribution, with a fixed target velocity along that path.
|
||||
"""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||
super().__init__(config, action_dim, horizon)
|
||||
self.do_mask_loss_for_padding = do_mask_loss_for_padding
|
||||
|
||||
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
|
||||
"""Sample timesteps according to configured strategy.
|
||||
|
||||
Uniform: Sample t uniformly from [0,1]
|
||||
Beta: Sample t from Beta(α,β) scaled to [0,s], emphasizing high noise (low t)
|
||||
"""
|
||||
if self.config.timestep_sampling.strategy_name == "uniform":
|
||||
return torch.rand(batch_size, device=device)
|
||||
elif self.config.timestep_sampling.strategy_name == "beta":
|
||||
# Sample u ~ Beta(α, β) then transform: t = s(1-u)
|
||||
# This emphasizes t near 0 (high noise) when α > β
|
||||
beta_dist = torch.distributions.Beta(
|
||||
self.config.timestep_sampling.alpha, self.config.timestep_sampling.beta
|
||||
)
|
||||
u = beta_dist.sample((batch_size,)).to(device)
|
||||
return self.config.timestep_sampling.s * (1.0 - u)
|
||||
else:
|
||||
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling.strategy_name}")
|
||||
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
"""Compute flow matching training loss.
|
||||
|
||||
Trains the model to predict the velocity field along linear interpolation paths.
|
||||
"""
|
||||
data = batch["action"] # Clean action sequences (B, T, D)
|
||||
batch_size = data.shape[0]
|
||||
device = data.device
|
||||
|
||||
noise = torch.randn_like(data)
|
||||
t = self._sample_timesteps(batch_size, device)
|
||||
t_expanded = t.view(-1, 1, 1) # (B, 1, 1) for broadcasting
|
||||
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
|
||||
|
||||
# The velocity we want the model to learn: v = data - (1-σ)·noise
|
||||
target_velocity = data - (1 - self.config.sigma_min) * noise
|
||||
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
|
||||
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
|
||||
|
||||
# Optionally mask padded actions
|
||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||
valid_mask = ~batch["action_is_pad"] # (B, T)
|
||||
loss = loss * valid_mask.unsqueeze(-1) # (B, T, D)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||
"""Generate actions by integrating the learned velocity field via ODE.
|
||||
|
||||
Solves: dx/dt = v_θ(x,t) from t=0 (noise) to t=1 (data)
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
# Start from random noise at t=0
|
||||
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
|
||||
|
||||
# Time grid from 0 to 1
|
||||
num_steps = self.config.num_integration_steps
|
||||
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
|
||||
|
||||
# Integrate ODE using chosen method
|
||||
if self.config.integration_method == "euler":
|
||||
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
|
||||
elif self.config.integration_method == "rk4":
|
||||
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
|
||||
else:
|
||||
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
|
||||
|
||||
return x
|
||||
|
||||
def _euler_integrate(
|
||||
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
||||
) -> Tensor:
|
||||
"""
|
||||
Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)
|
||||
"""
|
||||
x = x_init
|
||||
|
||||
for i in range(len(time_grid) - 1):
|
||||
t_scalar = time_grid[i].item()
|
||||
dt = (time_grid[i + 1] - time_grid[i]).item()
|
||||
|
||||
# Create time tensor for batch
|
||||
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
|
||||
|
||||
# Get velocity at current point
|
||||
with torch.no_grad():
|
||||
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
|
||||
|
||||
# Euler step
|
||||
x = x + dt * velocity
|
||||
|
||||
return x
|
||||
|
||||
def _rk4_integrate(
|
||||
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
|
||||
) -> Tensor:
|
||||
"""4th-order Runge-Kutta integration.
|
||||
|
||||
Uses 4 velocity evaluations per step:
|
||||
k1 = v(x, t)
|
||||
k2 = v(x + dt·k1/2, t + dt/2)
|
||||
k3 = v(x + dt·k2/2, t + dt/2)
|
||||
k4 = v(x + dt·k3, t + dt)
|
||||
x_next = x + dt/6·(k1 + 2k2 + 2k3 + k4)
|
||||
|
||||
4x slower than Euler but more accurate.
|
||||
|
||||
        Note: in practice, the extra accuracy rarely makes a noticeable difference here.
|
||||
"""
|
||||
x = x_init
|
||||
|
||||
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
|
||||
"""dynamics helper to get velocity at (x, t)"""
|
||||
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
|
||||
with torch.no_grad():
|
||||
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
|
||||
|
||||
for i in range(len(time_grid) - 1):
|
||||
t = time_grid[i].item()
|
||||
dt = (time_grid[i + 1] - time_grid[i]).item()
|
||||
|
||||
# RK4 stages
|
||||
k1 = dynamics(x, t)
|
||||
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
|
||||
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
|
||||
k4 = dynamics(x + dt * k3, t + dt)
|
||||
|
||||
# Weighted combination
|
||||
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
|
||||
|
||||
return x
|
||||
@@ -0,0 +1,396 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Observation encoding for Multi-Task DiT policy.
|
||||
|
||||
Handles vision encoding, text encoding, robot state, and environment state.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import einops
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torch import Tensor
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class BaseVisionEncoder(ABC):
|
||||
"""Abstract base class for vision encoders."""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Encode RGB image to feature maps."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_output_shape(self) -> tuple:
|
||||
"""Get the output shape (C', H', W')."""
|
||||
pass
|
||||
|
||||
|
||||
class DinoV3Encoder(nn.Module, BaseVisionEncoder):
|
||||
"""DinoV3 vision encoder using the CLS token for global image representation."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model_name = config.backbone
|
||||
|
||||
# Create the timm model
|
||||
self.model = timm.create_model(
|
||||
self.model_name,
|
||||
pretrained=True,
|
||||
num_classes=0,
|
||||
)
|
||||
|
||||
self.num_non_spatial_tokens = 5 # 1 CLS + 4 register
|
||||
self.embed_dim = self.model.embed_dim
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Encode RGB image to feature maps."""
|
||||
# Extract all features
|
||||
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
|
||||
|
||||
# Use only the CLS token (first token)
|
||||
cls_token = features[:, 0] # (B, embed_dim)
|
||||
b, embed_dim = cls_token.shape
|
||||
|
||||
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
|
||||
cls_features = cls_token.reshape(b, embed_dim, 1, 1)
|
||||
return cls_features
|
||||
|
||||
def get_output_shape(self) -> tuple:
|
||||
return (self.embed_dim, 1, 1)
|
||||
|
||||
|
||||
class CLIPEncoder(nn.Module, BaseVisionEncoder):
|
||||
"""CLIP vision encoder using the CLS token for global image representation."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model_name = config.backbone
|
||||
|
||||
# Create the timm model
|
||||
self.model = timm.create_model(
|
||||
self.model_name,
|
||||
pretrained=True,
|
||||
num_classes=0, # Remove classification head, we want features
|
||||
)
|
||||
|
||||
# CLIP models have 1 CLS token (no register tokens like DinoV3)
|
||||
self.num_non_spatial_tokens = 1
|
||||
|
||||
# Get embed_dim from model config
|
||||
self.embed_dim = self.model.embed_dim
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Encode RGB image to CLS token.
|
||||
|
||||
Preprocessing (resize, crop) is handled by ObservationEncoder
|
||||
"""
|
||||
# Extract all features
|
||||
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
|
||||
|
||||
# Use only the CLS token (first token)
|
||||
cls_token = features[:, 0] # (B, embed_dim)
|
||||
b, embed_dim = cls_token.shape
|
||||
|
||||
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
|
||||
cls_features = cls_token.reshape(b, embed_dim, 1, 1)
|
||||
return cls_features
|
||||
|
||||
def get_output_shape(self) -> tuple:
|
||||
return (self.embed_dim, 1, 1)
|
||||
|
||||
|
||||
def create_vision_encoder(config) -> BaseVisionEncoder:
|
||||
backbone_name = config.backbone.lower()
|
||||
|
||||
# Check if it's a CLIP model
|
||||
if "clip" in backbone_name:
|
||||
return CLIPEncoder(config)
|
||||
|
||||
# Check if it's a DinoV3 model
|
||||
elif "dinov3" in backbone_name:
|
||||
return DinoV3Encoder(config)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported vision backbone: {config.backbone}. "
|
||||
f"Currently supported: DinoV3 models and CLIP models"
|
||||
)
|
||||
|
||||
|
||||
# Registry for easy extension
|
||||
VISION_ENCODER_REGISTRY: dict[str, type] = {
|
||||
"dinov3": DinoV3Encoder,
|
||||
"clip": CLIPEncoder,
|
||||
}
|
||||
|
||||
|
||||
def register_vision_encoder(name: str, encoder_class: type):
|
||||
"""Register a new vision encoder type.
|
||||
|
||||
Args:
|
||||
name: Identifier for the encoder type
|
||||
encoder_class: Class implementing BaseVisionEncoder interface
|
||||
"""
|
||||
VISION_ENCODER_REGISTRY[name] = encoder_class
|
||||
|
||||
|
||||
def get_registered_encoders() -> dict[str, type]:
|
||||
"""Get all registered vision encoder types.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping encoder names to classes
|
||||
"""
|
||||
return VISION_ENCODER_REGISTRY.copy()
|
||||
|
||||
|
||||
class CLIPTextEncoder(nn.Module):
|
||||
"""CLIP text encoder with frozen weights and learnable projection."""
|
||||
|
||||
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
||||
super().__init__()
|
||||
|
||||
self.model_name = model_name
|
||||
self.projection_dim = projection_dim
|
||||
|
||||
# Load CLIP text encoder and tokenizer
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
|
||||
|
||||
# Freeze all CLIP text encoder parameters
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.text_embed_dim = self.text_encoder.config.hidden_size
|
||||
|
||||
# Learnable projection layer (always present, only trainable component)
|
||||
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
|
||||
|
||||
def forward(self, text: str | list[str]) -> Tensor:
|
||||
"""Encode text to feature vectors.
|
||||
|
||||
Args:
|
||||
text: Single string or list of strings
|
||||
|
||||
Returns:
|
||||
Text features of shape (B, projection_dim)
|
||||
"""
|
||||
# handle single string input
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
||||
|
||||
text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()}
|
||||
|
||||
# encode text through CLIP (frozen)
|
||||
with torch.no_grad():
|
||||
outputs = self.text_encoder(**text_inputs)
|
||||
# Extract pooled output (EOS token embedding)
|
||||
clip_features = outputs.pooler_output # (B, text_embed_dim)
|
||||
|
||||
# project to desired dimension (trainable)
|
||||
projected_features = self.projection(clip_features) # (B, projection_dim)
|
||||
|
||||
return projected_features
|
||||
|
||||
|
||||
class ObservationEncoder(nn.Module):
|
||||
"""Handles all observation processing for the conditioning vector."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
vision_config = config.observation_encoder.vision
|
||||
|
||||
self._setup_preprocessing(vision_config)
|
||||
|
||||
if config.image_features:
|
||||
self.num_cameras = len(config.image_features)
|
||||
self.camera_names = list(config.image_features.keys()) # Preserve ordering
|
||||
|
||||
if vision_config.use_separate_encoder_per_camera:
|
||||
self.vision_encoders = nn.ModuleList(
|
||||
[create_vision_encoder(vision_config) for _ in self.camera_names]
|
||||
)
|
||||
self.vision_encoder = None
|
||||
else:
|
||||
self.vision_encoder = create_vision_encoder(vision_config)
|
||||
self.vision_encoders = None
|
||||
else:
|
||||
self.vision_encoder = None
|
||||
self.vision_encoders = None
|
||||
self.camera_names = []
|
||||
self.num_cameras = 0
|
||||
|
||||
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
|
||||
self.robot_state_dim = config.robot_state_feature.shape[0]
|
||||
else:
|
||||
self.robot_state_dim = 0
|
||||
|
||||
if hasattr(config, "env_state_feature") and config.env_state_feature:
|
||||
self.env_state_dim = config.env_state_feature.shape[0]
|
||||
else:
|
||||
self.env_state_dim = 0
|
||||
|
||||
text_config = config.observation_encoder.text
|
||||
self.text_dim = config.transformer.hidden_dim
|
||||
self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
|
||||
|
||||
self._setup_vector_output()
|
||||
|
||||
def _apply_preprocessing(self, images: Tensor) -> Tensor:
|
||||
"""Apply preprocessing transforms to images."""
|
||||
if self.do_resize:
|
||||
images = self.resize(images)
|
||||
if self.do_crop:
|
||||
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
|
||||
|
||||
return images
|
||||
|
||||
def _setup_preprocessing(self, vision_config):
|
||||
"""Setup image preprocessing transforms."""
|
||||
if vision_config.resize_shape is not None:
|
||||
self.do_resize = True
|
||||
self.resize = torchvision.transforms.Resize(
|
||||
size=vision_config.resize_shape,
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
|
||||
antialias=True,
|
||||
)
|
||||
else:
|
||||
self.do_resize = False
|
||||
if vision_config.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
|
||||
if vision_config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
self.do_crop = False
|
||||
|
||||
def _setup_vector_output(self):
|
||||
"""Setup for vector output."""
|
||||
total_dim = 0
|
||||
|
||||
# Vision features - get CLS token feature dimension
|
||||
if self.vision_encoder is not None or self.vision_encoders is not None:
|
||||
            # nn.ModuleList has no .values(); iterate it directly
            encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
|
||||
|
||||
# Get output shape from encoder (deterministic for CLS tokens)
|
||||
feature_map_shape = encoder_to_check.get_output_shape()
|
||||
c, h, w = feature_map_shape
|
||||
spatial_feature_dim = c * h * w # For CLS token: embed_dim * 1 * 1 = embed_dim
|
||||
|
||||
total_dim += spatial_feature_dim * self.num_cameras
|
||||
|
||||
# State features
|
||||
total_dim += self.robot_state_dim
|
||||
total_dim += self.env_state_dim
|
||||
|
||||
# Text features
|
||||
total_dim += self.text_dim
|
||||
|
||||
# Account for temporal stacking
|
||||
self.conditioning_dim = total_dim * self.config.n_obs_steps
|
||||
|
||||
def encode(self, batch: dict) -> Tensor:
|
||||
"""Encode observations to vector format."""
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
conditioning_feats = []
|
||||
|
||||
conditioning_feats.append(batch[OBS_STATE])
|
||||
|
||||
if self.vision_encoder is not None or self.vision_encoders is not None:
|
||||
images = batch[OBS_IMAGES] # (B, n_obs_steps, num_cameras, C, H, W)
|
||||
|
||||
# Handle case when n_obs=1 and time dimension might be squeezed
|
||||
if len(images.shape) == 5:
|
||||
# Shape is (B, N, C, H, W) - add time dimension
|
||||
images = images.unsqueeze(1) # (B, 1, N, C, H, W)
|
||||
|
||||
vision_config = self.config.observation_encoder.vision
|
||||
if vision_config.use_separate_encoder_per_camera:
|
||||
# Process each camera with its own encoder
|
||||
camera_features = []
|
||||
|
||||
for cam_idx in range(self.num_cameras):
|
||||
# Extract images for this camera: (B, n_obs_steps, C, H, W)
|
||||
cam_images = images[:, :, cam_idx]
|
||||
|
||||
# Rearrange to: (B*n_obs_steps, C, H, W)
|
||||
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
|
||||
|
||||
# Apply preprocessing
|
||||
cam_images_flat = self._apply_preprocessing(cam_images_flat)
|
||||
|
||||
# Process with camera-specific encoder (direct index access)
|
||||
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
|
||||
|
||||
# Apply spatial vectorization (flatten CLS token features)
|
||||
cam_visual_features = cam_features.flatten(start_dim=1)
|
||||
|
||||
# Reshape back: (B*n_obs_steps, feature_dim) → (B, n_obs_steps, feature_dim)
|
||||
cam_features_reshaped = einops.rearrange(
|
||||
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
camera_features.append(cam_features_reshaped)
|
||||
|
||||
# Concatenate features from all cameras: (B, n_obs_steps, total_feature_dim)
|
||||
img_features = torch.cat(camera_features, dim=-1)
|
||||
conditioning_feats.append(img_features)
|
||||
|
||||
else:
|
||||
# Shared encoder for all cameras
|
||||
# Rearrange to: (B*n_obs_steps*num_cameras, C, H, W)
|
||||
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
|
||||
|
||||
images_flat = self._apply_preprocessing(images_flat)
|
||||
|
||||
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
|
||||
|
||||
# Reshape back and concatenate camera features
|
||||
# (B*n_obs_steps*num_cameras, feature_dim) → (B, n_obs_steps, num_cameras*feature_dim)
|
||||
img_features = einops.rearrange(
|
||||
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
|
||||
)
|
||||
|
||||
conditioning_feats.append(img_features)
|
||||
|
||||
if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
|
||||
conditioning_feats.append(batch[OBS_ENV_STATE])
|
||||
|
||||
if self.text_encoder is not None and "task" in batch:
|
||||
text_features = self.text_encoder(batch["task"]) # (B, text_dim)
|
||||
# Expand across temporal dimension to match other features
|
||||
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) # (B, T, text_dim)
|
||||
conditioning_feats.append(text_features)
|
||||
|
||||
combined_features = torch.cat(conditioning_feats, dim=-1) # (B, n_obs_steps, total_feature_dim)
|
||||
|
||||
return combined_features.flatten(start_dim=1) # (B, n_obs_steps * total_feature_dim)
|
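To make the bookkeeping in `_setup_vector_output` and `encode` concrete, here is a minimal sketch of the dimension arithmetic. All of the specific numbers (a 384-d CLS token per camera, two cameras, a 10-d robot state, no environment state, a 384-d text embedding, two observation steps) are illustrative assumptions, not values fixed by this policy.

# Illustrative dimension arithmetic only; encoder and feature sizes are assumptions.
embed_dim, num_cameras = 384, 2          # e.g. a ViT-S CLS token per camera
robot_state_dim, env_state_dim = 10, 0
text_dim = 384
n_obs_steps = 2

total_dim = embed_dim * num_cameras + robot_state_dim + env_state_dim + text_dim
conditioning_dim = total_dim * n_obs_steps
assert total_dim == 1162 and conditioning_dim == 2324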
||||
@@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Transformer backbone for noise prediction in Multi-Task DiT policy.
|
||||
|
||||
Adapted from DiT (Diffusion Transformer: https://github.com/facebookresearch/DiT) for 1D trajectory data.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
|
||||
"""Modulate input with shift and scale for AdaLN-Zero.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
shift: Shift parameter
|
||||
scale: Scale parameter
|
||||
|
||||
Returns:
|
||||
Modulated tensor: x * (1 + scale) + shift
|
||||
"""
|
||||
return x * (1 + scale) + shift
|
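# Usage sketch (commented so importing the module stays side-effect free; the
# shapes are illustrative assumptions): with x of shape (B, T, H) and
# per-sample AdaLN parameters of shape (B, H), the transformer block below
# unsqueezes them to (B, 1, H) so the modulation broadcasts over the sequence
# dimension:
#
#   x = torch.randn(2, 16, 128)
#   shift = torch.zeros(2, 1, 128)
#   scale = torch.zeros(2, 1, 128)
#   assert modulate(x, shift, scale).shape == (2, 16, 128)
#
# With zero shift and scale (the AdaLN-Zero initialization) the output equals x.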
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
"""Sinusoidal positional embeddings for timesteps.
|
||||
|
||||
Generates smooth sinusoidal embeddings of diffusion timestep values,
|
||||
matching the reference implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
"""
|
||||
Args:
|
||||
dim: Embedding dimension
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B,) tensor of timestep values
|
||||
|
||||
Returns:
|
||||
(B, dim) positional embeddings
|
||||
"""
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
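# Usage sketch (commented; the sizes are illustrative assumptions):
#
#   emb_layer = SinusoidalPosEmb(dim=128)
#   t = torch.tensor([0.0, 10.0, 99.0])
#   out = emb_layer(t)
#   assert out.shape == (3, 128)  # first 64 columns are sines, last 64 are cosines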
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
"""DiT-style transformer block with AdaLN-Zero.
|
||||
|
||||
Follows the official DiT formulation, using six adaptive layer-normalization parameters:
|
||||
- shift_msa, scale_msa, gate_msa: for attention block
|
||||
- shift_mlp, scale_mlp, gate_mlp: for MLP block
|
||||
|
||||
Reference: https://github.com/facebookresearch/DiT
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int = 128, num_heads: int = 4, num_features: int = 128, dropout: float = 0.0
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_size: Hidden dimension of transformer
|
||||
num_heads: Number of attention heads
|
||||
num_features: Size of conditioning features
|
||||
dropout: Dropout rate
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.multihead_attn = nn.MultiheadAttention(
|
||||
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
|
||||
)
|
||||
|
||||
# Layer normalizations (no learnable affine parameters, all adaptation via conditioning)
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
# Feed-forward network (MLP)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, hidden_size * 4),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(hidden_size * 4, hidden_size),
|
||||
)
|
||||
|
||||
# AdaLN-Zero modulation: produces 6 parameters (shift, scale, gate for attn and mlp)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, features: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, T, hidden_size) input sequence
|
||||
features: (B, num_features) conditioning features
|
||||
|
||||
Returns:
|
||||
(B, T, hidden_size) processed sequence
|
||||
"""
|
||||
# Generate 6 modulation parameters from conditioning
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
|
||||
features
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# Attention block: norm → modulate → attn → gate × output → residual
|
||||
# modulate requires unsqueeze(1) to add sequence dimension for broadcasting
|
||||
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
|
||||
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
|
||||
x = x + gate_msa.unsqueeze(1) * attn_out
|
||||
|
||||
# MLP block: norm → modulate → mlp → gate × output → residual
|
||||
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
|
||||
mlp_out = self.mlp(mlp_input)
|
||||
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
||||
|
||||
return x
|
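# Usage sketch (commented; the sizes are illustrative assumptions):
#
#   block = TransformerBlock(hidden_size=128, num_heads=4, num_features=256)
#   x = torch.randn(2, 16, 128)    # (B, T, hidden_size)
#   cond = torch.randn(2, 256)     # (B, num_features)
#   assert block(x, cond).shape == (2, 16, 128)
#
# After DiffusionTransformer._initialize_weights zeroes the adaLN_modulation
# output layer, all six parameters start at zero and each block initially
# reduces to the identity mapping.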
||||
|
||||
|
||||
class DiffusionTransformer(nn.Module):
|
||||
"""
|
||||
Transformer-based diffusion noise prediction model.
|
||||
"""
|
||||
|
||||
def __init__(self, config, conditioning_dim: int):
|
||||
"""Initialize transformer for noise prediction.
|
||||
|
||||
Args:
|
||||
config: MultiTaskDiTConfig with transformer parameters
|
||||
conditioning_dim: Dimension of concatenated observation features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.transformer_config = config.transformer
|
||||
self.conditioning_dim = conditioning_dim
|
||||
|
||||
self.action_dim = config.action_feature.shape[0]
|
||||
self.horizon = config.horizon
|
||||
self.hidden_size = self.transformer_config.hidden_dim
|
||||
self.num_layers = self.transformer_config.num_layers
|
||||
self.num_heads = self.transformer_config.num_heads
|
||||
self.dropout = self.transformer_config.dropout
|
||||
|
||||
self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
|
||||
self.time_mlp = nn.Sequential(
|
||||
SinusoidalPosEmb(self.timestep_embed_dim),
|
||||
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
self.cond_dim = self.timestep_embed_dim + conditioning_dim
|
||||
|
||||
# Project action dimensions to hidden size
|
||||
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
|
||||
|
||||
if self.transformer_config.use_positional_encoding:
|
||||
# Learnable positional embeddings for sequence positions
|
||||
self.pos_embedding = nn.Parameter(
|
||||
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
|
||||
)
|
||||
else:
|
||||
self.pos_embedding = None
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TransformerBlock(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
num_features=self.cond_dim,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Project back to action dimensions
|
||||
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
|
||||
|
||||
# Zero-initialize adaLN_modulation layers for AdaLN-Zero
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""
|
||||
Zero-initializing the final linear layer of adaLN_modulation in each block (AdaLN-Zero) improves training stability.
|
||||
"""
|
||||
for block in self.transformer_blocks:
|
||||
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
|
||||
"""Predict noise to remove from noisy actions.
|
||||
|
||||
Args:
|
||||
x: (B, T, action_dim) noisy action sequences
|
||||
timestep: (B,) diffusion timesteps
|
||||
conditioning_vec: (B, conditioning_dim) observation features (required)
|
||||
|
||||
Returns:
|
||||
(B, T, action_dim) predicted noise
|
||||
"""
|
||||
_, seq_len, _ = x.shape
|
||||
|
||||
timestep_features = self.time_mlp(timestep) # (B, timestep_embed_dim)
|
||||
|
||||
# Concatenate the required conditioning vector with the timestep embedding
|
||||
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1) # (B, cond_dim)
|
||||
|
||||
# Project action sequence to hidden dimension
|
||||
hidden_seq = self.input_proj(x) # (B, T, hidden_size)
|
||||
|
||||
if self.pos_embedding is not None:
|
||||
# Add learned positional embeddings
|
||||
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :] # (B, T, hidden_size)
|
||||
|
||||
# Pass through transformer layers with conditioning
|
||||
for block in self.transformer_blocks:
|
||||
hidden_seq = block(hidden_seq, cond_features) # (B, T, hidden_size)
|
||||
|
||||
# Project back to action space
|
||||
output = self.output_proj(hidden_seq) # (B, T, action_dim)
|
||||
|
||||
return output
|
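A quick smoke test of the backbone above, as a hedged sketch: the `SimpleNamespace` stub only mimics the attributes `DiffusionTransformer` reads, and its field values are arbitrary choices for illustration, not the policy defaults.

from types import SimpleNamespace

import torch

stub_config = SimpleNamespace(
    transformer=SimpleNamespace(
        hidden_dim=128,
        num_layers=2,
        num_heads=4,
        dropout=0.0,
        diffusion_step_embed_dim=64,
        use_positional_encoding=True,
    ),
    action_feature=SimpleNamespace(shape=(10,)),
    horizon=16,
)

model = DiffusionTransformer(stub_config, conditioning_dim=256)
noisy_actions = torch.randn(2, 16, 10)     # (B, horizon, action_dim)
timesteps = torch.randint(0, 100, (2,))    # (B,)
cond = torch.randn(2, 256)                 # (B, conditioning_dim)
assert model(noisy_actions, timesteps, cond).shape == (2, 16, 10)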
||||
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_multi_task_dit_pre_post_processors(
|
||||
config: MultiTaskDiTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Creates pre- and post-processing pipelines for the Multi-Task DiT policy."""
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
device=config.device,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
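As a hedged sketch of how these pipelines are meant to be used (relying on the imports already present in this file; `run_policy_step` and its arguments are illustrative helpers, not part of the module):

def run_policy_step(
    config: MultiTaskDiTConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
    policy,
    raw_observation: dict[str, Any],
) -> torch.Tensor:
    """Wrap a single select_action call with the pre/post processors (illustrative)."""
    preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config, dataset_stats)
    processed = preprocessor(raw_observation)   # rename keys, add batch dim, move to device, normalize
    action = policy.select_action(processed)
    return postprocessor(action)                # unnormalize and return the action on CPU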
||||
@@ -0,0 +1,377 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script for Multi-Task DiT policy.
|
||||
|
||||
To run tests with GPU on Modal (temporary script):
|
||||
modal run run_tests_modal.py
|
||||
|
||||
To run tests locally:
|
||||
python -m pytest tests/policies/test_multi_task_dit_policy.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import (
|
||||
DiffusionConfig,
|
||||
FlowMatchingConfig,
|
||||
MultiTaskDiTConfig,
|
||||
)
|
||||
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed():
|
||||
seed = 17
|
||||
set_seed(seed)
|
||||
|
||||
|
||||
def create_train_batch(
|
||||
batch_size: int = 2,
|
||||
n_obs_steps: int = 2,
|
||||
horizon: int = 16,
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 10,
|
||||
height: int = 224,
|
||||
width: int = 224,
|
||||
) -> dict[str, Tensor]:
|
||||
"""Create a training batch with visual input and text."""
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, n_obs_steps, state_dim),
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width),
|
||||
ACTION: torch.randn(batch_size, horizon, action_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch(
|
||||
batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224
|
||||
) -> dict:
|
||||
"""Create observation batch for inference for a single timestep."""
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width),
|
||||
"task": ["pick up the red cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def create_config(
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 10,
|
||||
n_obs_steps: int = 2,
|
||||
horizon: int = 16,
|
||||
n_action_steps: int = 8,
|
||||
with_visual: bool = True,
|
||||
height: int = 224,
|
||||
width: int = 224,
|
||||
) -> MultiTaskDiTConfig:
|
||||
"""Create a MultiTaskDiT config for testing.
|
||||
|
||||
Args:
|
||||
state_dim: Dimension of state observations
|
||||
action_dim: Dimension of actions
|
||||
n_obs_steps: Number of observation steps
|
||||
horizon: Action prediction horizon
|
||||
n_action_steps: Number of action steps to execute
|
||||
with_visual: Whether to include visual input (default: True)
|
||||
height: Image height (only used if with_visual=True)
|
||||
width: Image width (only used if with_visual=True)
|
||||
"""
|
||||
input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}
|
||||
|
||||
if with_visual:
|
||||
input_features[f"{OBS_IMAGES}.laptop"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(3, height, width)
|
||||
)
|
||||
|
||||
config = MultiTaskDiTConfig(
|
||||
input_features=input_features,
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
# Use smaller model for faster tests
|
||||
config.transformer.hidden_dim = 128
|
||||
config.transformer.num_layers = 2
|
||||
config.transformer.num_heads = 4
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
|
||||
def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int):
|
||||
"""Test forward pass (training mode)."""
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
assert loss.shape == ()
|
||||
|
||||
# Test backward pass
|
||||
loss.backward()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
|
||||
def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||
"""Test select_action (inference mode)."""
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.reset() # Reset queues before inference
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_diffusion_objective():
|
||||
"""Test policy with diffusion objective."""
|
||||
batch_size = 2
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
config.objective = DiffusionConfig(
|
||||
noise_scheduler_type="DDPM",
|
||||
num_train_timesteps=100,
|
||||
num_inference_steps=10,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
# Test inference
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_flow_matching_objective():
|
||||
"""Test policy with flow matching objective."""
|
||||
batch_size = 2
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
config.objective = FlowMatchingConfig(
|
||||
sigma_min=0.0,
|
||||
num_integration_steps=10, # Use fewer steps for faster tests
|
||||
integration_method="euler",
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Test forward pass
|
||||
loss, _ = policy.forward(batch)
|
||||
assert loss is not None
|
||||
assert loss.item() is not None
|
||||
|
||||
# Test inference
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_save_and_load(tmp_path):
|
||||
"""Test that the policy can be saved and loaded correctly."""
|
||||
root = tmp_path / "test_multi_task_dit_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
batch_size = 2
|
||||
n_obs_steps = 2
|
||||
horizon = 16
|
||||
n_action_steps = 8
|
||||
|
||||
config = create_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
# Get device before saving
|
||||
device = next(policy.parameters()).device
|
||||
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
|
||||
|
||||
# Explicitly move loaded_policy to the same device
|
||||
loaded_policy.to(device)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_train_batch(
|
||||
batch_size=batch_size,
|
||||
n_obs_steps=n_obs_steps,
|
||||
horizon=horizon,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
# Move batch to the same device as the policy
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Collect policy values before saving
|
||||
loss, _ = policy.forward(batch)
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Move observation batch to device
|
||||
for key in observation_batch:
|
||||
if isinstance(observation_batch[key], torch.Tensor):
|
||||
observation_batch[key] = observation_batch[key].to(device)
|
||||
actions = policy.select_action(observation_batch)
|
||||
|
||||
with seeded_context(12):
|
||||
# Collect policy values after loading
|
||||
loaded_loss, _ = loaded_policy.forward(batch)
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
# Move observation batch to device
|
||||
for key in loaded_observation_batch:
|
||||
if isinstance(loaded_observation_batch[key], torch.Tensor):
|
||||
loaded_observation_batch[key] = loaded_observation_batch[key].to(device)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
|
||||
# Compare state dicts
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
assert torch.allclose(loss, loaded_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
|
||||
|
||||
def test_multi_task_dit_policy_get_optim_params():
|
||||
"""Test that the policy returns correct optimizer parameter groups."""
|
||||
config = create_config(
|
||||
state_dim=10,
|
||||
action_dim=10,
|
||||
n_obs_steps=2,
|
||||
horizon=16,
|
||||
n_action_steps=8,
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
param_groups = policy.get_optim_params()
|
||||
|
||||
# Should have 2 parameter groups: non-vision and vision encoder
|
||||
assert len(param_groups) == 2
|
||||
|
||||
# First group is non-vision params (no lr specified, will use default)
|
||||
assert "params" in param_groups[0]
|
||||
assert len(param_groups[0]["params"]) > 0
|
||||
|
||||
# Second group is vision encoder params with different lr
|
||||
assert "params" in param_groups[1]
|
||||
assert "lr" in param_groups[1]
|
||||
expected_lr = config.optimizer_lr * config.observation_encoder.vision.lr_multiplier
|
||||
assert param_groups[1]["lr"] == expected_lr
|
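Beyond the assertions above, here is a hedged sketch of how these parameter groups would feed an optimizer in a training step. It reuses the test helpers defined in this file; the AdamW choice and the single-step loop are illustrative assumptions, not part of the test suite.

def example_training_step():
    # Not a test: illustrative only. The vision-encoder group carries its own
    # learning rate; the remaining parameters use the optimizer default.
    config = create_config()
    policy = MultiTaskDiTPolicy(config=config)
    policy.train()
    optimizer = torch.optim.AdamW(policy.get_optim_params(), lr=config.optimizer_lr)

    batch = create_train_batch()
    loss, _ = policy.forward(batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()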