propagate action_is_pad masking through VLA-JEPA policy pipeline

Pass the `action_is_pad` tensor from the batch through to the action head
so padded timesteps are excluded from the flow-matching loss.
This commit is contained in:
Maximellerbach
2026-05-15 14:34:21 +02:00
parent 3db774e0e2
commit ab5222c819
@@ -123,6 +123,11 @@ class VLAJEPAModel(nn.Module):
# Unpack optional per-example fields into parallel lists (or None when absent).
# `has_action` is defined outside this hunk.
actions = [ex["action"] for ex in examples] if has_action else None
# State is considered present only if the first example has a non-None "state".
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
# Padding masks are collected only when actions are present AND the first
# example carries a non-None "action_is_pad"; otherwise the mask stays None.
# NOTE(review): presence is checked on examples[0] only, so every other
# example is indexed unconditionally -- presumably all examples in a batch
# share the same keys; verify against the caller.
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
@@ -243,7 +248,20 @@ class VLAJEPAModel(nn.Module):
# Tile the action targets (and the state, when present) num_repeated times
# along the first dimension.
actions_rep = actions_target.repeat(num_repeated, 1, 1)
state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None
# NOTE(review): the value assigned on the next line is overwritten by the
# second `self.action_model(...)` call below. This looks like the pre-change
# line of the diff (add/remove markers are not visible in this view) --
# confirm it is not present in the resulting file, otherwise the action model
# is invoked twice.
action_loss = self.action_model(embodied_rep, actions_rep, state_rep)
# Stays None when no padding mask was supplied; the action model then receives
# None as its fourth argument.
action_is_pad_rep = None
if action_is_pad is not None:
# Convert each per-example mask to a tensor on the same device as the action
# targets (existing tensors are moved, anything else goes through
# torch.tensor), then stack them along a new leading dimension.
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
# Keep only the trailing `action_horizon` entries of each mask.
# NOTE(review): presumably mirrors how actions_target is cropped outside this
# hunk -- TODO confirm the two stay aligned. The two-index slice also requires
# the stacked mask to be at least 2-D.
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
# Tile the mask with the same repeat count used for actions_rep.
action_is_pad_rep = pad_tensor.repeat(num_repeated, 1) # [B*R, action_horizon]
action_loss = self.action_model(embodied_rep, actions_rep, state_rep, action_is_pad_rep)
# The world-model loss is scaled by its configured weight; the action loss is
# returned as computed.
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
@@ -414,11 +432,15 @@ class VLAJEPAPolicy(PreTrainedPolicy):
# ---- Collect actions (training only) ----
# Both lists default to None so later code can test for their presence.
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
# A 2-D action tensor gets a singleton dimension inserted at position 1.
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
# One detached float NumPy array on CPU per batch element.
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# Optional padding mask, split into one detached CPU tensor per batch element
# (unlike the actions, it is not converted to NumPy).
# NOTE(review): indentation is not visible here, so it is unclear whether this
# lookup is nested under the `actions_tensor is not None` check -- confirm.
# NOTE(review): the mask does not get the unsqueeze(1) applied to 2-D actions
# above; verify that consumers handle the resulting per-element shape.
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
@@ -440,6 +462,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
}
# Attach the optional fields for batch element `b` only when the corresponding
# per-batch list was collected; absent fields are simply left out of the dict.
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)