diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py index ee9d2d7da..fa8f90508 100644 --- a/src/lerobot/policies/vla_jepa/action_head.py +++ b/src/lerobot/policies/vla_jepa/action_head.py @@ -33,6 +33,18 @@ def swish(x: torch.Tensor) -> torch.Tensor: return x * torch.sigmoid(x) +class _MLP2(nn.Module): + """Two-layer GELU MLP with layer1/layer2 attribute names matching the original checkpoint.""" + + def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None: + super().__init__() + self.layer1 = nn.Linear(in_dim, hidden_dim) + self.layer2 = nn.Linear(hidden_dim, out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layer2(F.gelu(self.layer1(x))) + + class SinusoidalPositionalEncoding(nn.Module): def __init__(self, embedding_dim: int): super().__init__() @@ -51,9 +63,9 @@ class SinusoidalPositionalEncoding(nn.Module): class ActionEncoder(nn.Module): def __init__(self, action_dim: int, hidden_size: int): super().__init__() - self.w1 = nn.Linear(action_dim, hidden_size) - self.w2 = nn.Linear(hidden_size * 2, hidden_size) - self.w3 = nn.Linear(hidden_size, hidden_size) + self.layer1 = nn.Linear(action_dim, hidden_size) + self.layer2 = nn.Linear(hidden_size * 2, hidden_size) + self.layer3 = nn.Linear(hidden_size, hidden_size) self.pos_encoding = SinusoidalPositionalEncoding(hidden_size) def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: @@ -61,9 +73,9 @@ class ActionEncoder(nn.Module): if timesteps.ndim != 1 or timesteps.shape[0] != batch_size: raise ValueError("timesteps must have shape [batch_size].") timesteps = timesteps.unsqueeze(1).expand(-1, seq_len) - action_emb = self.w1(actions) + action_emb = self.layer1(actions) time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype) - return self.w3(swish(self.w2(torch.cat([action_emb, time_emb], dim=-1)))) + return self.layer3(swish(self.layer2(torch.cat([action_emb, time_emb], dim=-1)))) class TimestepEncoder(nn.Module): @@ -96,11 +108,12 @@ class BasicTransformerBlock(nn.Module): num_attention_heads: int, attention_head_dim: int, dropout: float, - cross_attention_dim: int, + cross_attention_dim: int | None, ) -> None: super().__init__() + self.is_cross_attention = cross_attention_dim is not None self.norm1 = AdaLayerNorm(dim) - self.attn = Attention( + self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, @@ -115,11 +128,11 @@ class BasicTransformerBlock(nn.Module): def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, temb: torch.Tensor, ) -> torch.Tensor: attn_input = self.norm1(hidden_states, temb) - hidden_states = hidden_states + self.attn(attn_input, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=encoder_hidden_states) hidden_states = hidden_states + self.ff(self.norm2(hidden_states)) return hidden_states @@ -140,16 +153,17 @@ class DiT(ModelMixin, ConfigMixin): super().__init__() self.inner_dim = num_attention_heads * attention_head_dim self.timestep_encoder = TimestepEncoder(self.inner_dim) - self.blocks = nn.ModuleList( + self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, dropout=dropout, - cross_attention_dim=cross_attention_dim, + # Even blocks attend to context (cross-attention), odd blocks are self-attention. + cross_attention_dim=cross_attention_dim if i % 2 == 0 else None, ) - for _ in range(num_layers) + for i in range(num_layers) ] ) self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False) @@ -164,8 +178,9 @@ class DiT(ModelMixin, ConfigMixin): ) -> torch.Tensor: temb = self.timestep_encoder(timestep) x = hidden_states - for block in self.blocks: - x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb) + for block in self.transformer_blocks: + es = encoder_hidden_states if block.is_cross_attention else None + x = block(x, encoder_hidden_states=es, temb=temb) shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1) x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None] return self.proj_out_2(x) @@ -181,6 +196,7 @@ class ActionModelPreset: DIT_PRESETS = { "DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12), "DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32), + "DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2), } @@ -189,37 +205,34 @@ class VLAJEPAActionHead(nn.Module): super().__init__() preset = DIT_PRESETS[config.action_model_type] self.config = config - self.input_embedding_dim = preset.hidden_size + num_heads = preset.num_attention_heads + head_dim = preset.attention_head_dim + inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768 + + self.input_embedding_dim = inner_dim self.action_horizon = config.future_action_window_size + 1 self.num_inference_timesteps = config.num_inference_timesteps self.model = DiT( - num_attention_heads=config.action_num_heads or preset.num_attention_heads, - attention_head_dim=config.action_attention_head_dim or preset.attention_head_dim, + num_attention_heads=num_heads, + attention_head_dim=head_dim, output_dim=config.action_hidden_size, num_layers=config.action_num_layers, dropout=config.action_dropout, cross_attention_dim=cross_attention_dim, ) - self.action_encoder = ActionEncoder(config.action_dim, config.action_hidden_size) - self.action_decoder = nn.Sequential( - nn.Linear(config.action_hidden_size, config.action_hidden_size), - nn.GELU(), - nn.Linear(config.action_hidden_size, config.action_dim), - ) + # action_encoder/decoder and state_encoder use action_hidden_size (DiT output dim). + # action_encoder and state_encoder produce inner_dim-sized tokens (DiT input width). + # action_decoder takes DiT output (action_hidden_size) and produces action_dim predictions. + self.action_encoder = ActionEncoder(config.action_dim, inner_dim) + self.action_decoder = _MLP2(config.action_hidden_size, config.action_hidden_size, config.action_dim) self.state_encoder = ( - nn.Sequential( - nn.Linear(config.state_dim, config.action_hidden_size), - nn.GELU(), - nn.Linear(config.action_hidden_size, config.action_hidden_size), - ) - if config.state_dim > 0 - else None - ) - self.future_tokens = nn.Embedding(config.num_action_tokens_per_timestep, config.action_hidden_size) - self.position_embedding = nn.Embedding( - config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size + _MLP2(config.state_dim, config.action_hidden_size, inner_dim) if config.state_dim > 0 else None ) + # future_tokens and position_embedding operate at inner_dim (DiT input width), + # not at action_hidden_size (DiT output width). + self.future_tokens = nn.Embedding(config.num_target_vision_tokens, inner_dim) + self.position_embedding = nn.Embedding(config.action_max_seq_len, inner_dim) self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta) def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: @@ -250,6 +263,7 @@ class VLAJEPAActionHead(nn.Module): conditioning_tokens: torch.Tensor, actions: torch.Tensor, state: torch.Tensor | None = None, + action_is_pad: torch.Tensor | None = None, ) -> torch.Tensor: noise = torch.randn_like(actions) t = self.sample_time(actions.shape[0], actions.device, actions.dtype) @@ -264,7 +278,14 @@ class VLAJEPAActionHead(nn.Module): timestep=t_discretized, ) pred_actions = self.action_decoder(pred[:, -actions.shape[1] :]) - return F.mse_loss(pred_actions, velocity, reduction="mean") + + if action_is_pad is None: + action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device) + + loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim] + valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1] + num_valid = valid_mask.sum() * loss.shape[-1] + return (loss * valid_mask).sum() / num_valid.clamp_min(1) @torch.no_grad() def predict_action( diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index 65070f62b..b23594101 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -23,7 +23,7 @@ class VLAJEPAConfig(PreTrainedConfig): } ) - qwen_model_name: str = "Qwen/Qwen3-VL-4B-Instruct" + qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct" jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256" tokenizer_padding_side: str = "left" @@ -44,13 +44,13 @@ class VLAJEPAConfig(PreTrainedConfig): action_hidden_size: int = 1024 action_model_type: str = "DiT-B" action_num_layers: int = 12 - action_num_heads: int = 16 - action_attention_head_dim: int = 64 action_dropout: float = 0.1 action_num_timestep_buckets: int = 1000 action_noise_beta_alpha: float = 1.5 action_noise_beta_beta: float = 1.0 action_noise_s: float = 0.999 + num_target_vision_tokens: int = 32 + action_max_seq_len: int = 1024 # total video frames loaded per sample num_video_frames: int = 16 diff --git a/src/lerobot/policies/vla_jepa/world_model.py b/src/lerobot/policies/vla_jepa/world_model.py index 4e32706eb..4a398e7df 100644 --- a/src/lerobot/policies/vla_jepa/world_model.py +++ b/src/lerobot/policies/vla_jepa/world_model.py @@ -1,6 +1,7 @@ from __future__ import annotations import torch +import torch.nn.functional as F # noqa: N812 from torch import nn @@ -10,11 +11,50 @@ def build_block_causal_attention_mask(num_steps: int, tokens_per_step: int, cond for current_step in range(num_steps): row_start = current_step * (tokens_per_step + cond_tokens) row_end = row_start + tokens_per_step + cond_tokens - allowed_end = row_end - mask[row_start:row_end, :allowed_end] = 0 + mask[row_start:row_end, :row_end] = 0 return mask +class _Attention(nn.Module): + def __init__(self, embed_dim: int, num_heads: int) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True) + self.proj = nn.Linear(embed_dim, embed_dim, bias=True) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor: + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + return self.proj(x.transpose(1, 2).reshape(b, n, c)) + + +class _MLP(nn.Module): + def __init__(self, embed_dim: int, mlp_ratio: float) -> None: + super().__init__() + hidden = int(embed_dim * mlp_ratio) + self.fc1 = nn.Linear(embed_dim, hidden) + self.fc2 = nn.Linear(hidden, embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc2(F.gelu(self.fc1(x))) + + +class _PredictorBlock(nn.Module): + def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(embed_dim) + self.attn = _Attention(embed_dim, num_heads) + self.norm2 = nn.LayerNorm(embed_dim) + self.mlp = _MLP(embed_dim, mlp_ratio) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor: + x = x + self.attn(self.norm1(x), attn_mask=attn_mask) + return x + self.mlp(self.norm2(x)) + + class ActionConditionedVideoPredictor(nn.Module): def __init__( self, @@ -29,17 +69,11 @@ class ActionConditionedVideoPredictor(nn.Module): super().__init__() self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim) self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim) - encoder_layer = nn.TransformerEncoderLayer( - d_model=predictor_embed_dim, - nhead=num_heads, - dim_feedforward=int(predictor_embed_dim * mlp_ratio), - dropout=0.0, - activation="gelu", - batch_first=True, + self.predictor_blocks = nn.ModuleList( + [_PredictorBlock(predictor_embed_dim, num_heads, mlp_ratio) for _ in range(depth)] ) - self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth) - self.norm = nn.LayerNorm(predictor_embed_dim) - self.proj = nn.Linear(predictor_embed_dim, embed_dim) + self.predictor_norm = nn.LayerNorm(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim) self.num_action_tokens_per_step = num_action_tokens_per_step def forward(self, frame_tokens: torch.Tensor, action_tokens: torch.Tensor) -> torch.Tensor: @@ -50,9 +84,9 @@ class ActionConditionedVideoPredictor(nn.Module): frame_tokens = self.predictor_embed(frame_tokens) action_tokens = self.action_encoder(action_tokens) - fused_steps = [] - for step in range(num_steps): - fused_steps.append(torch.cat([action_tokens[:, step], frame_tokens[:, step]], dim=1)) + fused_steps = [ + torch.cat([action_tokens[:, step], frame_tokens[:, step]], dim=1) for step in range(num_steps) + ] fused = torch.cat(fused_steps, dim=1) attn_mask = build_block_causal_attention_mask( @@ -60,7 +94,11 @@ class ActionConditionedVideoPredictor(nn.Module): tokens_per_step=tokens_per_frame, cond_tokens=self.num_action_tokens_per_step, ).to(device=fused.device, dtype=fused.dtype) - encoded = self.encoder(fused, mask=attn_mask) - encoded = encoded.view(batch_size, num_steps, self.num_action_tokens_per_step + tokens_per_frame, -1) - predicted_frame_tokens = encoded[:, :, self.num_action_tokens_per_step :, :] - return self.proj(self.norm(predicted_frame_tokens)) + + for block in self.predictor_blocks: + fused = block(fused, attn_mask=attn_mask) + + fused = self.predictor_norm(fused) + fused = fused.view(batch_size, num_steps, self.num_action_tokens_per_step + tokens_per_frame, -1) + predicted_frame_tokens = fused[:, :, self.num_action_tokens_per_step :, :] + return self.predictor_proj(predicted_frame_tokens)