class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) for transformers.

    RoPE encodes position by rotating query and key vectors by a
    position-dependent angle, so that the dot product between a query at
    position ``m`` and a key at position ``n`` depends only on ``m - n``.
    It is applied inside every attention layer rather than once at the input.

    Because the rotation has to be applied to Q and K before the attention
    scores are computed, the built-in ``nn.MultiheadAttention`` module cannot
    be used; attention must be reimplemented around this module.

    Original RoPE Paper: https://arxiv.org/abs/2104.09864 (RoFormer)
    """

    def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
        """
        Args:
            head_dim: Per-head feature size; must be a positive even number
                because features are rotated in pairs.
            max_seq_len: Number of positions for which cos/sin tables are
                precomputed.
            base: Base of the geometric frequency progression.

        Raises:
            ValueError: If ``head_dim`` is not a positive even number or
                ``max_seq_len`` is not positive.
        """
        super().__init__()
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if head_dim <= 0 or head_dim % 2 != 0:
            raise ValueError(f"head_dim must be a positive even number for RoPE, got {head_dim}")
        if max_seq_len <= 0:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")

        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Inverse frequencies: theta_i = 1 / (base^(2i/d)), one per feature pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        # Non-persistent: fully determined by the constructor arguments, so it
        # is kept out of the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._precompute_cache(max_seq_len)

    def _precompute_cache(self, seq_len: int) -> None:
        """Build the cos/sin tables for positions ``0 .. seq_len - 1``.

        The tables are registered as non-persistent buffers of shape
        ``(1, 1, seq_len, head_dim)``.
        """
        # Create positions on the same device as inv_freq so the cache can be
        # rebuilt after the module has been moved.
        positions = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)

        # (seq_len, head_dim / 2): angle of every feature pair at every position.
        angles = torch.outer(positions, self.inv_freq)

        # Duplicate so both halves of a vector share the angles; this matches
        # the half-split pairing used by `_rotate_half`.
        table = torch.cat((angles, angles), dim=-1)

        self.register_buffer("_cos_cached", table.cos()[None, None, :, :], persistent=False)
        self.register_buffer("_sin_cached", table.sin()[None, None, :, :], persistent=False)

    def _rotate_half(self, x: Tensor) -> Tensor:
        """Rotate half the hidden dims of the input.

        For x = [x1, x2], returns [-x2, x1]
        """
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
        """Apply rotary embeddings to query and key tensors.

        Args:
            q: Query tensor; dimension 2 is the sequence axis and the last
                dimension must equal ``head_dim``.
            k: Key tensor laid out like ``q`` with the same sequence length.

        Returns:
            Tuple ``(q_rotated, k_rotated)`` with the same shapes as the inputs.

        Raises:
            ValueError: If the sequence length exceeds ``max_seq_len``, if
                ``q`` and ``k`` disagree on sequence length, or if their last
                dimension is not ``head_dim``.
        """
        seq_len = q.shape[2]

        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}. "
                f"Increase max_seq_len in RoPE config."
            )
        # Without these checks a length-1 key (or width-1 feature axis) would
        # silently broadcast against the tables and yield wrong rotations.
        if k.shape[2] != seq_len:
            raise ValueError(
                f"q and k must share the same sequence length, got {seq_len} and {k.shape[2]}."
            )
        if q.shape[-1] != self.head_dim or k.shape[-1] != self.head_dim:
            raise ValueError(
                f"Last dimension of q and k must equal head_dim {self.head_dim}, "
                f"got {q.shape[-1]} and {k.shape[-1]}."
            )

        # Slice the precomputed tables to the actual sequence length and match
        # the inputs' device and dtype.
        cos = self._cos_cached[:, :, :seq_len, :].to(device=q.device, dtype=q.dtype)
        sin = self._sin_cached[:, :, :seq_len, :].to(device=q.device, dtype=q.dtype)

        # Apply rotation: x_rot = x * cos + rotate_half(x) * sin
        q_rotated = (q * cos) + (self._rotate_half(q) * sin)
        k_rotated = (k * cos) + (self._rotate_half(k) * sin)

        return q_rotated, k_rotated
class RoPEAttention(nn.Module):
    """Multi-head self-attention with Rotary Position Embedding (RoPE).

    Custom attention implementation that applies RoPE to Q and K before
    computing attention scores. This allows position information to be
    encoded at every attention layer.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float = 0.0,
        max_seq_len: int = 512,
        rope_base: float = 10000.0,
    ):
        """
        Args:
            hidden_size: Total hidden dimension
            num_heads: Number of attention heads
            dropout: Attention dropout rate, in [0, 1]
            max_seq_len: Maximum sequence length for RoPE cache
            rope_base: Base for RoPE frequency computation

        Raises:
            ValueError: If ``num_heads`` is not positive, ``hidden_size`` is
                not divisible by ``num_heads``, the resulting head dimension
                is odd, or ``dropout`` lies outside [0, 1].
        """
        super().__init__()
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads, got "
                f"hidden_size={hidden_size}, num_heads={num_heads}"
            )
        if (hidden_size // num_heads) % 2 != 0:
            raise ValueError(
                f"head_dim (hidden_size // num_heads) must be even for RoPE, "
                f"got {hidden_size // num_heads}"
            )
        # A negative rate used to fall through `dropout > 0` and silently
        # disable dropout; reject it like nn.Dropout does.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be between 0 and 1, got {dropout}")

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Not passed to the attention call in forward(), which is left on its
        # default scaling; retained as an attribute for compatibility.
        self.scale = self.head_dim**-0.5

        # Fused projection producing Q, K and V in one matmul.
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
        # Kept for backward compatibility; forward() reads `dropout_p` and
        # never calls this module.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.dropout_p = float(dropout)
        self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, T, hidden_size) input sequence

        Returns:
            (B, T, hidden_size) attention output
        """
        batch, seq_len, _ = x.shape

        # Project once, then split into per-head Q, K, V.
        qkv = self.qkv_proj(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # Each: (B, num_heads, T, head_dim)

        # Apply RoPE to Q and K (V carries no positional rotation).
        q, k = self.rope(q, k)

        # Attention dropout is only active in training mode.
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout_p if self.training else 0.0,
        )  # (B, num_heads, T, head_dim)

        # Merge heads back into the hidden dimension and project.
        attn_out = attn_out.transpose(1, 2).reshape(batch, seq_len, self.hidden_size)
        return self.out_proj(attn_out)
shift_msa.unsqueeze(1), scale_msa.unsqueeze(1)) - attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input) + + if self.use_rope: + attn_out = self.attn(attn_input) + else: + attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input) + x = x + gate_msa.unsqueeze(1) * attn_out # MLP block: norm → modulate → mlp → gate × output → residual @@ -163,6 +331,7 @@ class DiffusionTransformer(nn.Module): self.num_layers = self.transformer_config.num_layers self.num_heads = self.transformer_config.num_heads self.dropout = self.transformer_config.dropout + self.use_rope = self.transformer_config.use_rope self.timestep_embed_dim = self.transformer_config.diffusion_step_embed_dim self.time_mlp = nn.Sequential( @@ -179,7 +348,7 @@ class DiffusionTransformer(nn.Module): self.input_proj = nn.Linear(self.action_dim, self.hidden_size) if self.transformer_config.use_positional_encoding: - # Learnable positional embeddings for sequence positions + # Learnable positional embeddings for sequence positions (absolute encoding) self.pos_embedding = nn.Parameter( torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02) ) @@ -193,6 +362,9 @@ class DiffusionTransformer(nn.Module): num_heads=self.num_heads, num_features=self.cond_dim, dropout=self.dropout, + use_rope=self.use_rope, + max_seq_len=self.horizon, # This remains fixed because we aren't generating variable length sequences + rope_base=getattr(self.transformer_config, "rope_base", 10000.0), ) for _ in range(self.num_layers) ]