mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
add RoPE attention module as this is shown to help training dynamics and generation quality for DiTs
This commit is contained in:
@@ -167,9 +167,13 @@ class TransformerConfig:
|
|||||||
num_layers: int = 6 # Number of transformer layers
|
num_layers: int = 6 # Number of transformer layers
|
||||||
num_heads: int = 8 # Number of attention heads
|
num_heads: int = 8 # Number of attention heads
|
||||||
dropout: float = 0.1 # Dropout rate
|
dropout: float = 0.1 # Dropout rate
|
||||||
use_positional_encoding: bool = True # Whether to use positional encoding
|
use_positional_encoding: bool = False # Whether to use absolute positional encoding
|
||||||
diffusion_step_embed_dim: int = 256 # Timestep embedding size
|
diffusion_step_embed_dim: int = 256 # Timestep embedding size
|
||||||
|
|
||||||
|
# RoPE (Rotary Position Embedding) configuration
|
||||||
|
use_rope: bool = True # Whether to use Rotary Position Embedding in attention (baseline is True)
|
||||||
|
rope_base: float = 10000.0 # Base frequency for RoPE computation
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Validate Transformer-specific parameters."""
|
"""Validate Transformer-specific parameters."""
|
||||||
if self.hidden_dim <= 0:
|
if self.hidden_dim <= 0:
|
||||||
|
|||||||
@@ -71,6 +71,146 @@ class SinusoidalPosEmb(nn.Module):
|
|||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) for transformers.

    RoPE encodes position information by rotating query and key vectors,
    which naturally captures relative positions through the dot product.
    Applied at every attention layer rather than once at input.

    To do this, we need to reimplement the attention mechanism to apply RoPE
    to Q and K before computing the attention scores, so we cannot use the
    built-in MultiheadAttention module.

    Original RoPE Paper: https://arxiv.org/abs/2104.09864 (RoFormer)
    """

    def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
        """
        Args:
            head_dim: Size of the last dimension of the tensors to rotate; must be even
                because the rotation pairs the first half of the features with the second.
            max_seq_len: Number of positions for which cos/sin tables are precomputed.
            base: Base used in the inverse-frequency formula ``1 / base^(2i / head_dim)``.

        Raises:
            ValueError: If ``head_dim`` is odd.
        """
        super().__init__()
        # Explicit raise instead of `assert`: asserts vanish under `python -O`, and an odd
        # head_dim would then only surface later as a shape mismatch in forward().
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")

        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute inverse frequencies: theta_i = 1 / (base^(2i/d))
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        # Non-persistent: derived from the constructor arguments, so not stored in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._precompute_cache(max_seq_len)

    def _precompute_cache(self, seq_len: int) -> None:
        """Build the cos/sin lookup tables for positions ``0 .. seq_len - 1``.

        The tables are registered as non-persistent buffers of shape
        ``(1, 1, seq_len, head_dim)`` so they broadcast over the two leading
        dimensions of the tensors passed to ``forward``.
        """
        # Create the position index next to inv_freq so the outer product never mixes devices.
        positions = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)

        # (seq_len, head_dim / 2): angle of position p for frequency i is p * theta_i
        freqs = torch.outer(positions, self.inv_freq)

        # Duplicate the angles so feature i and feature i + head_dim/2 share one angle,
        # matching the half-split layout used by _rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def _rotate_half(self, x: Tensor) -> Tensor:
        """Rotate half the hidden dims of the input.

        For x = [x1, x2], returns [-x2, x1]
        """
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
        """Apply rotary embeddings to query and key tensors.

        The sequence length is read from dimension 2 of ``q`` and the same
        cos/sin slice is applied to both ``q`` and ``k``.

        Args:
            q: Query tensor whose dimension 2 is the sequence axis and whose last
                dimension is ``head_dim``.
            k: Key tensor; must be broadcast-compatible with the tables sliced for ``q``.

        Returns:
            Tuple ``(q_rotated, k_rotated)``.

        Raises:
            ValueError: If the sequence length of ``q`` exceeds ``max_seq_len``.
        """
        seq_len = q.shape[2]

        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}. "
                "Increase max_seq_len in RoPE config."
            )

        # Slice precomputed cache to actual sequence length and match the input dtype
        cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
        sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)

        # Apply rotation: q_rot = q * cos + rotate_half(q) * sin
        q_rotated = (q * cos) + (self._rotate_half(q) * sin)
        k_rotated = (k * cos) + (self._rotate_half(k) * sin)

        return q_rotated, k_rotated
|
||||||
|
|
||||||
|
|
||||||
|
class RoPEAttention(nn.Module):
    """Multi-head self-attention with Rotary Position Embedding (RoPE).

    Custom attention implementation that applies RoPE to Q and K before
    computing attention scores. This allows position information to be
    encoded at every attention layer.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float = 0.0,
        max_seq_len: int = 512,
        rope_base: float = 10000.0,
    ):
        """
        Args:
            hidden_size: Total hidden dimension
            num_heads: Number of attention heads
            dropout: Attention dropout rate
            max_seq_len: Maximum sequence length for RoPE cache
            rope_base: Base for RoPE frequency computation

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
        """
        super().__init__()
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Not passed to scaled_dot_product_attention in forward(); retained as an attribute.
        self.scale = self.head_dim**-0.5

        # Single fused projection producing Q, K and V; split per head in forward().
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
        # Only used as a holder for the dropout probability read in forward().
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, T, hidden_size) input sequence

        Returns:
            (B, T, hidden_size) attention output
        """
        B, T, _ = x.shape  # noqa: N806

        # Compute Q, K, V
        qkv = self.qkv_proj(x)  # (B, T, 3 * hidden_size)
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, T, head_dim)
        q, k, v = qkv.unbind(0)  # Each: (B, num_heads, T, head_dim)

        # Apply RoPE to Q and K
        q, k = self.rope(q, k)

        # Attention dropout is active only in training mode and only when a real
        # nn.Dropout was configured (dropout > 0 at construction).
        use_dropout = self.training and isinstance(self.dropout, nn.Dropout)
        dropout_p = self.dropout.p if use_dropout else 0.0

        # Scaled dot-product attention
        # Using PyTorch's efficient attention when available
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=dropout_p,
        )  # (B, num_heads, T, head_dim)

        # Reshape and project output
        attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)  # (B, T, hidden_size)
        return self.out_proj(attn_out)
|
||||||
|
|
||||||
|
|
||||||
class TransformerBlock(nn.Module):
|
class TransformerBlock(nn.Module):
|
||||||
"""DiT-style transformer block with AdaLN-Zero.
|
"""DiT-style transformer block with AdaLN-Zero.
|
||||||
|
|
||||||
@@ -78,11 +218,20 @@ class TransformerBlock(nn.Module):
|
|||||||
- shift_msa, scale_msa, gate_msa: for attention block
|
- shift_msa, scale_msa, gate_msa: for attention block
|
||||||
- shift_mlp, scale_mlp, gate_mlp: for MLP block
|
- shift_mlp, scale_mlp, gate_mlp: for MLP block
|
||||||
|
|
||||||
|
Supports both standard attention and RoPE attention.
|
||||||
|
|
||||||
Reference: https://github.com/facebookresearch/DiT
|
Reference: https://github.com/facebookresearch/DiT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hidden_size: int = 128, num_heads: int = 4, num_features: int = 128, dropout: float = 0.0
|
self,
|
||||||
|
hidden_size: int = 128,
|
||||||
|
num_heads: int = 4,
|
||||||
|
num_features: int = 128,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
use_rope: bool = False,
|
||||||
|
max_seq_len: int = 512,
|
||||||
|
rope_base: float = 10000.0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -90,12 +239,26 @@ class TransformerBlock(nn.Module):
|
|||||||
num_heads: Number of attention heads
|
num_heads: Number of attention heads
|
||||||
num_features: Size of conditioning features
|
num_features: Size of conditioning features
|
||||||
dropout: Dropout rate
|
dropout: Dropout rate
|
||||||
|
use_rope: Whether to use Rotary Position Embedding
|
||||||
|
max_seq_len: Maximum sequence length (for RoPE cache)
|
||||||
|
rope_base: Base frequency for RoPE
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.multihead_attn = nn.MultiheadAttention(
|
self.use_rope = use_rope
|
||||||
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
|
|
||||||
)
|
if use_rope:
|
||||||
|
self.attn = RoPEAttention(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
rope_base=rope_base,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.multihead_attn = nn.MultiheadAttention(
|
||||||
|
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
|
||||||
|
)
|
||||||
|
|
||||||
# Layer normalizations (no learnable affine parameters, all adaptation via conditioning)
|
# Layer normalizations (no learnable affine parameters, all adaptation via conditioning)
|
||||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
@@ -128,7 +291,12 @@ class TransformerBlock(nn.Module):
|
|||||||
# Attention block: norm → modulate → attn → gate × output → residual
|
# Attention block: norm → modulate → attn → gate × output → residual
|
||||||
# modulate requires unsqueeze(1) to add sequence dimension for broadcasting
|
# modulate requires unsqueeze(1) to add sequence dimension for broadcasting
|
||||||
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
|
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
|
||||||
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
|
|
||||||
|
if self.use_rope:
|
||||||
|
attn_out = self.attn(attn_input)
|
||||||
|
else:
|
||||||
|
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
|
||||||
|
|
||||||
x = x + gate_msa.unsqueeze(1) * attn_out
|
x = x + gate_msa.unsqueeze(1) * attn_out
|
||||||
|
|
||||||
# MLP block: norm → modulate → mlp → gate × output → residual
|
# MLP block: norm → modulate → mlp → gate × output → residual
|
||||||
@@ -163,6 +331,7 @@ class DiffusionTransformer(nn.Module):
|
|||||||
self.num_layers = self.transformer_config.num_layers
|
self.num_layers = self.transformer_config.num_layers
|
||||||
self.num_heads = self.transformer_config.num_heads
|
self.num_heads = self.transformer_config.num_heads
|
||||||
self.dropout = self.transformer_config.dropout
|
self.dropout = self.transformer_config.dropout
|
||||||
|
self.use_rope = self.transformer_config.use_rope
|
||||||
|
|
||||||
self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
|
self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim
|
||||||
self.time_mlp = nn.Sequential(
|
self.time_mlp = nn.Sequential(
|
||||||
@@ -179,7 +348,7 @@ class DiffusionTransformer(nn.Module):
|
|||||||
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
|
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
|
||||||
|
|
||||||
if self.transformer_config.use_positional_encoding:
|
if self.transformer_config.use_positional_encoding:
|
||||||
# Learnable positional embeddings for sequence positions
|
# Learnable positional embeddings for sequence positions (absolute encoding)
|
||||||
self.pos_embedding = nn.Parameter(
|
self.pos_embedding = nn.Parameter(
|
||||||
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
|
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
|
||||||
)
|
)
|
||||||
@@ -193,6 +362,9 @@ class DiffusionTransformer(nn.Module):
|
|||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
num_features=self.cond_dim,
|
num_features=self.cond_dim,
|
||||||
dropout=self.dropout,
|
dropout=self.dropout,
|
||||||
|
use_rope=self.use_rope,
|
||||||
|
max_seq_len=self.horizon, # This remains fixed because we aren't generating variable length sequences
|
||||||
|
rope_base=getattr(self.transformer_config, "rope_base", 10000.0),
|
||||||
)
|
)
|
||||||
for _ in range(self.num_layers)
|
for _ in range(self.num_layers)
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user