mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 02:29:47 +00:00
align VLA-JEPA architecture with original checkpoint
- Remove stale `action_num_heads` / `action_attention_head_dim` config fields; DiT head dimensions are now always derived from the preset (DiT-B/L/test).
- Add `num_target_vision_tokens` and `action_max_seq_len` config fields required by the action head's future-token embedding and positional embedding tables.
- Fix default `qwen_model_name` to 2B (matches all released checkpoints).
- Rename `ActionEncoder` attrs w1/w2/w3 → layer1/layer2/layer3 to match checkpoint key names; replace `nn.Sequential` decoder/state-encoder with `_MLP2` (layer1/layer2 naming).
- Fix `VLAJEPAActionHead` to size ActionEncoder and StateEncoder at `inner_dim` (DiT input width) rather than `action_hidden_size` (DiT output width).
- Rename `DiT.blocks` → `transformer_blocks` and `attn` → `attn1` to match checkpoint; add alternating cross/self attention (even blocks cross-attend to Qwen context, odd blocks self-attend).
- Add `DiT-test` preset for unit tests.
- Rewrite `ActionConditionedVideoPredictor` with explicit ViT-style blocks (`_PredictorBlock` with fused qkv) to match checkpoint structure; rename `encoder`/`norm`/`proj` → `predictor_blocks`/`predictor_norm`/`predictor_proj`.
This commit is contained in:
@@ -33,6 +33,18 @@ def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class _MLP2(nn.Module):
    """Linear -> GELU -> Linear projection.

    The two sub-modules are named ``layer1`` and ``layer2`` so that their
    parameter keys match the names used by the original checkpoint.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()
        # layer1 is created before layer2 so seeded parameter init is unchanged.
        self.layer1 = nn.Linear(in_features=in_dim, out_features=hidden_dim)
        self.layer2 = nn.Linear(in_features=hidden_dim, out_features=out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` through ``layer1``, apply GELU, then ``layer2``."""
        hidden = self.layer1(x)
        activated = F.gelu(hidden)
        return self.layer2(activated)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
@@ -51,9 +63,9 @@ class SinusoidalPositionalEncoding(nn.Module):
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(action_dim, hidden_size)
|
||||
self.w2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.w3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.layer1 = nn.Linear(action_dim, hidden_size)
|
||||
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
@@ -61,9 +73,9 @@ class ActionEncoder(nn.Module):
|
||||
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("timesteps must have shape [batch_size].")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||
action_emb = self.w1(actions)
|
||||
action_emb = self.layer1(actions)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
return self.w3(swish(self.w2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
return self.layer3(swish(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
@@ -96,11 +108,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
cross_attention_dim: int | None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn = Attention(
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
@@ -115,11 +128,11 @@ class BasicTransformerBlock(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
hidden_states = hidden_states + self.attn(attn_input, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
@@ -140,16 +153,17 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
super().__init__()
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||
self.blocks = nn.ModuleList(
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
# Even blocks attend to context (cross-attention), odd blocks are self-attention.
|
||||
cross_attention_dim=cross_attention_dim if i % 2 == 0 else None,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
@@ -164,8 +178,9 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
) -> torch.Tensor:
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.blocks:
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
for block in self.transformer_blocks:
|
||||
es = encoder_hidden_states if block.is_cross_attention else None
|
||||
x = block(x, encoder_hidden_states=es, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
@@ -181,6 +196,7 @@ class ActionModelPreset:
|
||||
DIT_PRESETS = {
|
||||
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
|
||||
}
|
||||
|
||||
|
||||
@@ -189,37 +205,34 @@ class VLAJEPAActionHead(nn.Module):
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
self.input_embedding_dim = preset.hidden_size
|
||||
num_heads = preset.num_attention_heads
|
||||
head_dim = preset.attention_head_dim
|
||||
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||
|
||||
self.input_embedding_dim = inner_dim
|
||||
self.action_horizon = config.future_action_window_size + 1
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
self.model = DiT(
|
||||
num_attention_heads=config.action_num_heads or preset.num_attention_heads,
|
||||
attention_head_dim=config.action_attention_head_dim or preset.attention_head_dim,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
output_dim=config.action_hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, config.action_hidden_size)
|
||||
self.action_decoder = nn.Sequential(
|
||||
nn.Linear(config.action_hidden_size, config.action_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(config.action_hidden_size, config.action_dim),
|
||||
)
|
||||
# action_encoder/decoder and state_encoder use action_hidden_size (DiT output dim).
|
||||
# action_encoder and state_encoder produce inner_dim-sized tokens (DiT input width).
|
||||
# action_decoder takes DiT output (action_hidden_size) and produces action_dim predictions.
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = _MLP2(config.action_hidden_size, config.action_hidden_size, config.action_dim)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
nn.Linear(config.state_dim, config.action_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(config.action_hidden_size, config.action_hidden_size),
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_action_tokens_per_timestep, config.action_hidden_size)
|
||||
self.position_embedding = nn.Embedding(
|
||||
config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size
|
||||
_MLP2(config.state_dim, config.action_hidden_size, inner_dim) if config.state_dim > 0 else None
|
||||
)
|
||||
# future_tokens and position_embedding operate at inner_dim (DiT input width),
|
||||
# not at action_hidden_size (DiT output width).
|
||||
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, inner_dim)
|
||||
self.position_embedding = nn.Embedding(config.action_max_seq_len, inner_dim)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
@@ -250,6 +263,7 @@ class VLAJEPAActionHead(nn.Module):
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
action_is_pad: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
@@ -264,7 +278,14 @@ class VLAJEPAActionHead(nn.Module):
|
||||
timestep=t_discretized,
|
||||
)
|
||||
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||
return F.mse_loss(pred_actions, velocity, reduction="mean")
|
||||
|
||||
if action_is_pad is None:
|
||||
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
|
||||
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
|
||||
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
|
||||
num_valid = valid_mask.sum() * loss.shape[-1]
|
||||
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
|
||||
@@ -23,7 +23,7 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-4B-Instruct"
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
@@ -44,13 +44,13 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 12
|
||||
action_num_heads: int = 16
|
||||
action_attention_head_dim: int = 64
|
||||
action_dropout: float = 0.1
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
num_target_vision_tokens: int = 32
|
||||
action_max_seq_len: int = 1024
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 16
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
|
||||
@@ -10,11 +11,50 @@ def build_block_causal_attention_mask(num_steps: int, tokens_per_step: int, cond
|
||||
for current_step in range(num_steps):
|
||||
row_start = current_step * (tokens_per_step + cond_tokens)
|
||||
row_end = row_start + tokens_per_step + cond_tokens
|
||||
allowed_end = row_end
|
||||
mask[row_start:row_end, :allowed_end] = 0
|
||||
mask[row_start:row_end, :row_end] = 0
|
||||
return mask
|
||||
|
||||
|
||||
class _Attention(nn.Module):
    """Multi-head self-attention with a single fused ``qkv`` projection.

    Queries, keys and values are all computed from the same input by one
    ``nn.Linear`` (``qkv``); the attended result is mapped back to
    ``embed_dim`` by ``proj``.

    Args:
        embed_dim: Width of the input and output token embeddings.
        num_heads: Number of attention heads; must be positive and divide
            ``embed_dim`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``embed_dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        # Fail early: otherwise floor division silently truncates head_dim and
        # the mismatch only shows up as a reshape error inside forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive num_heads ({num_heads})."
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor unpacked as ``(batch, tokens, channels)``.
            attn_mask: Optional mask forwarded unchanged to
                ``F.scaled_dot_product_attention``.

        Returns:
            Tensor with the same ``(batch, tokens, channels)`` shape as ``x``.
        """
        b, n, c = x.shape
        # (b, n, 3*c) -> (3, b, heads, n, head_dim) so q/k/v can be unbound on dim 0.
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # (b, heads, n, head_dim) -> (b, n, c): merge the heads back into channels.
        return self.proj(x.transpose(1, 2).reshape(b, n, c))
|
||||
|
||||
|
||||
class _MLP(nn.Module):
    """Feed-forward sub-layer: ``fc1`` -> GELU -> ``fc2``.

    The hidden width is ``int(embed_dim * mlp_ratio)`` and the output width
    equals the input width ``embed_dim``.
    """

    def __init__(self, embed_dim: int, mlp_ratio: float) -> None:
        super().__init__()
        expanded_dim = int(embed_dim * mlp_ratio)
        # fc1 is created before fc2 so seeded parameter init is unchanged.
        self.fc1 = nn.Linear(in_features=embed_dim, out_features=expanded_dim)
        self.fc2 = nn.Linear(in_features=expanded_dim, out_features=embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand with ``fc1``, apply GELU, and project back with ``fc2``."""
        expanded = self.fc1(x)
        return self.fc2(F.gelu(expanded))
|
||||
|
||||
|
||||
class _PredictorBlock(nn.Module):
    """Pre-norm transformer block: attention, then MLP, each with a residual add.

    Each sub-layer sees a LayerNorm-ed copy of the running activations and its
    output is added back onto the un-normalized activations.
    """

    def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float) -> None:
        super().__init__()
        # Sub-modules are created in the same order as before (norm1, attn,
        # norm2, mlp) so seeded parameter init is unchanged.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = _Attention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = _MLP(embed_dim, mlp_ratio)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Run the attention and MLP sub-layers; ``attn_mask`` is passed to the attention only."""
        attended = self.attn(self.norm1(x), attn_mask=attn_mask)
        residual = x + attended
        transformed = self.mlp(self.norm2(residual))
        return residual + transformed
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -29,17 +69,11 @@ class ActionConditionedVideoPredictor(nn.Module):
|
||||
super().__init__()
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim)
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=predictor_embed_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=int(predictor_embed_dim * mlp_ratio),
|
||||
dropout=0.0,
|
||||
activation="gelu",
|
||||
batch_first=True,
|
||||
self.predictor_blocks = nn.ModuleList(
|
||||
[_PredictorBlock(predictor_embed_dim, num_heads, mlp_ratio) for _ in range(depth)]
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
|
||||
self.norm = nn.LayerNorm(predictor_embed_dim)
|
||||
self.proj = nn.Linear(predictor_embed_dim, embed_dim)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
def forward(self, frame_tokens: torch.Tensor, action_tokens: torch.Tensor) -> torch.Tensor:
|
||||
@@ -50,9 +84,9 @@ class ActionConditionedVideoPredictor(nn.Module):
|
||||
|
||||
frame_tokens = self.predictor_embed(frame_tokens)
|
||||
action_tokens = self.action_encoder(action_tokens)
|
||||
fused_steps = []
|
||||
for step in range(num_steps):
|
||||
fused_steps.append(torch.cat([action_tokens[:, step], frame_tokens[:, step]], dim=1))
|
||||
fused_steps = [
|
||||
torch.cat([action_tokens[:, step], frame_tokens[:, step]], dim=1) for step in range(num_steps)
|
||||
]
|
||||
fused = torch.cat(fused_steps, dim=1)
|
||||
|
||||
attn_mask = build_block_causal_attention_mask(
|
||||
@@ -60,7 +94,11 @@ class ActionConditionedVideoPredictor(nn.Module):
|
||||
tokens_per_step=tokens_per_frame,
|
||||
cond_tokens=self.num_action_tokens_per_step,
|
||||
).to(device=fused.device, dtype=fused.dtype)
|
||||
encoded = self.encoder(fused, mask=attn_mask)
|
||||
encoded = encoded.view(batch_size, num_steps, self.num_action_tokens_per_step + tokens_per_frame, -1)
|
||||
predicted_frame_tokens = encoded[:, :, self.num_action_tokens_per_step :, :]
|
||||
return self.proj(self.norm(predicted_frame_tokens))
|
||||
|
||||
for block in self.predictor_blocks:
|
||||
fused = block(fused, attn_mask=attn_mask)
|
||||
|
||||
fused = self.predictor_norm(fused)
|
||||
fused = fused.view(batch_size, num_steps, self.num_action_tokens_per_step + tokens_per_frame, -1)
|
||||
predicted_frame_tokens = fused[:, :, self.num_action_tokens_per_step :, :]
|
||||
return self.predictor_proj(predicted_frame_tokens)
|
||||
|
||||
Reference in New Issue
Block a user