Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-11 14:49:43 +00:00)
add RoPE attention module as this is shown to help training dynamics and generation quality for DiTs
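For context, the new module applies, inside every attention layer, a position-dependent rotation to each (j, j + head_dim/2) channel pair of the query and key vectors. A minimal reference sketch of that rotation, equivalent to the x * cos + rotate_half(x) * sin form used in the diff below (illustrative only; names and shapes are not from this diff):

import torch

def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (seq_len, head_dim). The pair (x[m, j], x[m, j + d/2]) is rotated by
    # angle m * theta_j, with theta_j = base ** (-2j / d), so dot products between
    # rotated queries and keys depend only on relative positions.
    m, d = x.shape
    theta = base ** (-2.0 * torch.arange(d // 2).float() / d)   # (d/2,)
    ang = torch.arange(m).float()[:, None] * theta[None, :]     # (m, d/2)
    x1, x2 = x[:, : d // 2], x[:, d // 2 :]
    return torch.cat((x1 * ang.cos() - x2 * ang.sin(), x1 * ang.sin() + x2 * ang.cos()), dim=-1)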
@@ -167,9 +167,13 @@ class TransformerConfig:
     num_layers: int = 6  # Number of transformer layers
     num_heads: int = 8  # Number of attention heads
     dropout: float = 0.1  # Dropout rate
-    use_positional_encoding: bool = True  # Whether to use positional encoding
+    use_positional_encoding: bool = False  # Whether to use absolute positional encoding
     diffusion_step_embed_dim: int = 256  # Timestep embedding size
 
+    # RoPE (Rotary Position Embedding) configuration
+    use_rope: bool = True  # Whether to use Rotary Position Embedding in attention (baseline is True)
+    rope_base: float = 10000.0  # Base frequency for RoPE computation
+
     def __post_init__(self):
         """Validate Transformer-specific parameters."""
         if self.hidden_dim <= 0:
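A sketch of a config that opts into RoPE and disables the learnable absolute embeddings (only the fields shown in this diff are set; all other fields are assumed to keep their defaults):

cfg = TransformerConfig(
    use_positional_encoding=False,  # skip the learnable absolute position embeddings
    use_rope=True,                  # rotate Q/K inside every attention layer instead
    rope_base=10000.0,              # base frequency for the rotation angles
)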
@@ -71,6 +71,146 @@ class SinusoidalPosEmb(nn.Module):
         return emb
 
 
+class RotaryPositionalEmbedding(nn.Module):
+    """Rotary Position Embedding (RoPE) for transformers.
+
+    RoPE encodes position information by rotating query and key vectors,
+    which naturally captures relative positions through the dot product.
+    Applied at every attention layer rather than once at the input.
+
+    To do this, we need to reimplement the attention mechanism to apply RoPE
+    to Q and K before computing the attention scores, so we cannot use
+    the built-in MultiheadAttention module.
+
+    Original RoPE Paper: https://arxiv.org/abs/2104.09864 (RoFormer)
+    """
+
+    def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
+        super().__init__()
+        assert head_dim % 2 == 0, "head_dim must be even for RoPE"
+
+        self.head_dim = head_dim
+        self.max_seq_len = max_seq_len
+        self.base = base
+
+        # Precompute inverse frequencies: theta_i = 1 / (base^(2i/d))
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        self._precompute_cache(max_seq_len)
+
+    def _precompute_cache(self, seq_len: int):
+        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
+
+        freqs = torch.outer(t, self.inv_freq)
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
+        self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
+
+    def _rotate_half(self, x: Tensor) -> Tensor:
+        """Rotate half the hidden dims of the input.
+
+        For x = [x1, x2], returns [-x2, x1]
+        """
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
+        """Apply rotary embeddings to query and key tensors."""
+        seq_len = q.shape[2]
+
+        if seq_len > self.max_seq_len:
+            raise ValueError(
+                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}. "
+                f"Increase max_seq_len in RoPE config."
+            )
+
+        # Slice precomputed cache to actual sequence length
+        cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
+        sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
+
+        # Apply rotation: q_rot = q * cos + rotate_half(q) * sin
+        q_rotated = (q * cos) + (self._rotate_half(q) * sin)
+        k_rotated = (k * cos) + (self._rotate_half(k) * sin)
+
+        return q_rotated, k_rotated
+
+
+class RoPEAttention(nn.Module):
+    """Multi-head self-attention with Rotary Position Embedding (RoPE).
+
+    Custom attention implementation that applies RoPE to Q and K before
+    computing attention scores. This allows position information to be
+    encoded at every attention layer.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        max_seq_len: int = 512,
+        rope_base: float = 10000.0,
+    ):
+        """
+        Args:
+            hidden_size: Total hidden dimension
+            num_heads: Number of attention heads
+            dropout: Attention dropout rate
+            max_seq_len: Maximum sequence length for RoPE cache
+            rope_base: Base for RoPE frequency computation
+        """
+        super().__init__()
+        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
+
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
+        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
+        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+        self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: (B, T, hidden_size) input sequence
+
+        Returns:
+            (B, T, hidden_size) attention output
+        """
+        B, T, _ = x.shape  # noqa: N806
+
+        # Compute Q, K, V
+        qkv = self.qkv_proj(x)  # (B, T, 3 * hidden_size)
+        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
+        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, T, head_dim)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # Each: (B, num_heads, T, head_dim)
+
+        # Apply RoPE to Q and K
+        q, k = self.rope(q, k)
+
+        # Scaled dot-product attention
+        # Using PyTorch's efficient attention when available
+        attn_out = torch.nn.functional.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
+        )  # (B, num_heads, T, head_dim)
+
+        # Reshape and project output
+        attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)  # (B, T, hidden_size)
+        output = self.out_proj(attn_out)
+
+        return output
+
+
 class TransformerBlock(nn.Module):
     """DiT-style transformer block with AdaLN-Zero.
 
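Assuming the RoPEAttention class above is in scope, a minimal shape check (values are illustrative, not from the diff):

import torch

attn = RoPEAttention(hidden_size=128, num_heads=4, max_seq_len=16)
x = torch.randn(2, 16, 128)   # (B, T, hidden_size)
y = attn(x)                   # -> (2, 16, 128); RoPE is applied to Q and K internally
assert y.shape == x.shape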
@@ -78,11 +218,20 @@ class TransformerBlock(nn.Module):
     - shift_msa, scale_msa, gate_msa: for attention block
     - shift_mlp, scale_mlp, gate_mlp: for MLP block
 
+    Supports both standard attention and RoPE attention.
+
     Reference: https://github.com/facebookresearch/DiT
     """
 
     def __init__(
-        self, hidden_size: int = 128, num_heads: int = 4, num_features: int = 128, dropout: float = 0.0
+        self,
+        hidden_size: int = 128,
+        num_heads: int = 4,
+        num_features: int = 128,
+        dropout: float = 0.0,
+        use_rope: bool = False,
+        max_seq_len: int = 512,
+        rope_base: float = 10000.0,
     ):
         """
         Args:
@@ -90,12 +239,26 @@ class TransformerBlock(nn.Module):
             num_heads: Number of attention heads
             num_features: Size of conditioning features
             dropout: Dropout rate
+            use_rope: Whether to use Rotary Position Embedding
+            max_seq_len: Maximum sequence length (for RoPE cache)
+            rope_base: Base frequency for RoPE
         """
         super().__init__()
 
-        self.multihead_attn = nn.MultiheadAttention(
-            hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
-        )
+        self.use_rope = use_rope
+
+        if use_rope:
+            self.attn = RoPEAttention(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout=dropout,
+                max_seq_len=max_seq_len,
+                rope_base=rope_base,
+            )
+        else:
+            self.multihead_attn = nn.MultiheadAttention(
+                hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
+            )
 
         # Layer normalizations (no learnable affine parameters, all adaptation via conditioning)
         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
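With the extended signature, the attention implementation is chosen at construction time; a brief sketch using the defaults from this diff (the conditioning path is unchanged):

rope_block = TransformerBlock(hidden_size=128, num_heads=4, num_features=128, use_rope=True, max_seq_len=16)  # builds RoPEAttention
plain_block = TransformerBlock(hidden_size=128, num_heads=4, num_features=128)  # keeps nn.MultiheadAttention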
@@ -128,7 +291,12 @@ class TransformerBlock(nn.Module):
         # Attention block: norm → modulate → attn → gate × output → residual
         # modulate requires unsqueeze(1) to add sequence dimension for broadcasting
         attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
-        attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
+
+        if self.use_rope:
+            attn_out = self.attn(attn_input)
+        else:
+            attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
+
         x = x + gate_msa.unsqueeze(1) * attn_out
 
         # MLP block: norm → modulate → mlp → gate × output → residual
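The two branches differ in call convention only: the RoPE module is self-attention over a single input and returns one tensor, while nn.MultiheadAttention takes query/key/value explicitly and also returns attention weights. A side-by-side sketch (tensor names and sizes are illustrative):

import torch
from torch import nn

x = torch.randn(2, 16, 128)
y_rope = RoPEAttention(hidden_size=128, num_heads=4, max_seq_len=16)(x)        # single tensor out
y_mha, _ = nn.MultiheadAttention(128, num_heads=4, batch_first=True)(x, x, x)  # (output, weights) out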
@@ -163,6 +331,7 @@ class DiffusionTransformer(nn.Module):
         self.num_layers = self.transformer_config.num_layers
         self.num_heads = self.transformer_config.num_heads
         self.dropout = self.transformer_config.dropout
+        self.use_rope = self.transformer_config.use_rope
 
         self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
         self.time_mlp = nn.Sequential(
@@ -179,7 +348,7 @@ class DiffusionTransformer(nn.Module):
         self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
 
         if self.transformer_config.use_positional_encoding:
-            # Learnable positional embeddings for sequence positions
+            # Learnable positional embeddings for sequence positions (absolute encoding)
             self.pos_embedding = nn.Parameter(
                 torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
             )
@@ -193,6 +362,9 @@ class DiffusionTransformer(nn.Module):
                     num_heads=self.num_heads,
                     num_features=self.cond_dim,
                     dropout=self.dropout,
+                    use_rope=self.use_rope,
+                    max_seq_len=self.horizon,  # This remains fixed because we aren't generating variable length sequences
+                    rope_base=getattr(self.transformer_config, "rope_base", 10000.0),
                 )
                 for _ in range(self.num_layers)
             ]
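Finally, a quick numerical check of the relative-position property the RotaryPositionalEmbedding docstring mentions: after rotation, the query/key dot product depends only on the position offset. This sketch assumes the classes above are in scope; sizes are illustrative:

import torch

rope = RotaryPositionalEmbedding(head_dim=8, max_seq_len=32)
q = torch.randn(1, 1, 1, 8).expand(1, 1, 32, 8)  # same query vector at every position
k = torch.randn(1, 1, 1, 8).expand(1, 1, 32, 8)  # same key vector at every position
qr, kr = rope(q, k)

def score(m: int, n: int) -> torch.Tensor:
    # attention-style dot product between rotated query at position m and key at position n
    return (qr[0, 0, m] * kr[0, 0, n]).sum()

print(torch.allclose(score(2, 5), score(20, 23), atol=1e-5))  # True: both offsets are 3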