add RoPE attention module, as RoPE has been shown to improve training dynamics and generation quality for DiTs

Bryson Jones
2025-12-09 08:42:56 -08:00
parent a0d5a088e3
commit 46ebcc2f7d
2 changed files with 183 additions and 7 deletions
@@ -167,9 +167,13 @@ class TransformerConfig:
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = True # Whether to use positional encoding
use_positional_encoding: bool = False # Whether to use absolute positional encoding
diffusion_step_embed_dim: int = 256 # Timestep embedding size
# RoPE (Rotary Position Embedding) configuration
use_rope: bool = True # Whether to use Rotary Position Embedding in attention (enabled by default)
rope_base: float = 10000.0 # Base frequency for RoPE computation
def __post_init__(self):
"""Validate Transformer-specific parameters."""
if self.hidden_dim <= 0:
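For reference, a minimal sketch of enabling RoPE through this config; field names come from the diff above, and any field not set here (e.g. hidden_dim) is assumed to keep its default:

config = TransformerConfig(
    num_layers=6,
    num_heads=8,
    use_positional_encoding=False,  # absolute positional embeddings off
    use_rope=True,                  # positions injected inside every attention layer instead
    rope_base=10000.0,
)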
@@ -71,6 +71,146 @@ class SinusoidalPosEmb(nn.Module):
return emb
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) for transformers.
RoPE encodes position information by rotating query and key vectors,
which naturally captures relative positions through the dot product.
Applied at every attention layer rather than once at input.
To do this, the attention mechanism must be reimplemented so that RoPE is
applied to Q and K before the attention scores are computed, so we cannot
use the built-in MultiheadAttention module.
Original RoPE Paper: https://arxiv.org/abs/2104.09864 (RoFormer)
"""
def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
super().__init__()
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
# Precompute inverse frequencies: theta_i = 1 / (base^(2i/d))
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
"""Precompute cos/sin tables of shape (1, 1, seq_len, head_dim)."""
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
# Duplicate frequencies so cos/sin line up with the two halves used by _rotate_half
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
def _rotate_half(self, x: Tensor) -> Tensor:
"""Rotate half the hidden dims of the input.
For x = [x1, x2], returns [-x2, x1]
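e.g. with head_dim = 4: [1, 2, 3, 4] -> [-3, -4, 1, 2]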
"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
"""Apply rotary embeddings to query and key tensors of shape (B, num_heads, T, head_dim)."""
seq_len = q.shape[2]
if seq_len > self.max_seq_len:
raise ValueError(
f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}. "
f"Increase max_seq_len in RoPE config."
)
# Slice precomputed cache to actual sequence length
cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
# Apply rotation: q_rot = q * cos + rotate_half(q) * sin
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
return q_rotated, k_rotated
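As a sanity check (an illustrative sketch, not part of this commit): because rotated dot products depend only on the relative offset n - m, placing the same content vector at every position yields a score matrix that is constant along each diagonal:

import torch

torch.manual_seed(0)
head_dim, T = 64, 32
rope = RotaryPositionalEmbedding(head_dim=head_dim, max_seq_len=T)

# Same content vector repeated at every position: (B, num_heads, T, head_dim)
q = torch.randn(1, 1, 1, head_dim).expand(1, 1, T, head_dim)
k = torch.randn(1, 1, 1, head_dim).expand(1, 1, T, head_dim)
q_rot, k_rot = rope(q, k)

scores = q_rot[0, 0] @ k_rot[0, 0].T  # (T, T) dot products
# (2, 5) and (10, 13) share relative offset +3, so their scores should agree
assert torch.allclose(scores[2, 5], scores[10, 13], atol=1e-5)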
class RoPEAttention(nn.Module):
"""Multi-head self-attention with Rotary Position Embedding (RoPE).
Custom attention implementation that applies RoPE to Q and K before
computing attention scores. This allows position information to be
encoded at every attention layer.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout: float = 0.0,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
"""
Args:
hidden_size: Total hidden dimension
num_heads: Number of attention heads
dropout: Attention dropout rate
max_seq_len: Maximum sequence length for RoPE cache
rope_base: Base for RoPE frequency computation
"""
super().__init__()
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5 # kept for reference; scaled_dot_product_attention applies this scaling internally
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B, T, hidden_size) input sequence
Returns:
(B, T, hidden_size) attention output
"""
B, T, _ = x.shape # noqa: N806
# Compute Q, K, V
qkv = self.qkv_proj(x) # (B, T, 3 * hidden_size)
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, num_heads, T, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2] # Each: (B, num_heads, T, head_dim)
# Apply RoPE to Q and K
q, k = self.rope(q, k)
# Scaled dot-product attention via PyTorch's fused kernel,
# which dispatches to flash/memory-efficient implementations when available
attn_out = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
) # (B, num_heads, T, head_dim)
# Reshape and project output
attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size) # (B, T, hidden_size)
output = self.out_proj(attn_out)
return output
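A minimal usage sketch of the module above (shapes only; the hyperparameters here are illustrative):

import torch

attn = RoPEAttention(hidden_size=128, num_heads=8, dropout=0.1, max_seq_len=64)
x = torch.randn(4, 16, 128)  # (B, T, hidden_size)
out = attn(x)
assert out.shape == (4, 16, 128)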
class TransformerBlock(nn.Module):
"""DiT-style transformer block with AdaLN-Zero.
@@ -78,11 +218,20 @@ class TransformerBlock(nn.Module):
- shift_msa, scale_msa, gate_msa: for attention block
- shift_mlp, scale_mlp, gate_mlp: for MLP block
Supports both standard attention and RoPE attention.
Reference: https://github.com/facebookresearch/DiT
"""
def __init__(
self, hidden_size: int = 128, num_heads: int = 4, num_features: int = 128, dropout: float = 0.0
self,
hidden_size: int = 128,
num_heads: int = 4,
num_features: int = 128,
dropout: float = 0.0,
use_rope: bool = False,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
"""
Args:
@@ -90,12 +239,26 @@ class TransformerBlock(nn.Module):
num_heads: Number of attention heads
num_features: Size of conditioning features
dropout: Dropout rate
use_rope: Whether to use Rotary Position Embedding
max_seq_len: Maximum sequence length (for RoPE cache)
rope_base: Base frequency for RoPE
"""
super().__init__()
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
self.use_rope = use_rope
if use_rope:
self.attn = RoPEAttention(
hidden_size=hidden_size,
num_heads=num_heads,
dropout=dropout,
max_seq_len=max_seq_len,
rope_base=rope_base,
)
else:
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
# Layer normalizations (no learnable affine parameters, all adaptation via conditioning)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
@@ -128,7 +291,12 @@ class TransformerBlock(nn.Module):
# Attention block: norm → modulate → attn → gate × output → residual
# modulate requires unsqueeze(1) to add sequence dimension for broadcasting
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
if self.use_rope:
attn_out = self.attn(attn_input)
else:
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
x = x + gate_msa.unsqueeze(1) * attn_out
# MLP block: norm → modulate → mlp → gate × output → residual
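The modulate helper invoked above is not part of this diff; in the DiT reference implementation it is the AdaLN shift-and-scale, sketched here under that assumption:

def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    # AdaLN modulation (assumed, per the DiT reference): shift and scale
    # arrive as (B, 1, hidden_size) after the unsqueeze(1) at the call site
    return x * (1 + scale) + shift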
@@ -163,6 +331,7 @@ class DiffusionTransformer(nn.Module):
self.num_layers = self.transformer_config.num_layers
self.num_heads = self.transformer_config.num_heads
self.dropout = self.transformer_config.dropout
self.use_rope = self.transformer_config.use_rope
self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
self.time_mlp = nn.Sequential(
@@ -179,7 +348,7 @@ class DiffusionTransformer(nn.Module):
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if self.transformer_config.use_positional_encoding:
# Learnable positional embeddings for sequence positions
# Learnable positional embeddings for sequence positions (absolute encoding)
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
)
@@ -193,6 +362,9 @@ class DiffusionTransformer(nn.Module):
num_heads=self.num_heads,
num_features=self.cond_dim,
dropout=self.dropout,
use_rope=self.use_rope,
max_seq_len=self.horizon, # the horizon is fixed because we aren't generating variable-length sequences
rope_base=getattr(self.transformer_config, "rope_base", 10000.0),
)
for _ in range(self.num_layers)
]