From ab5222c819cd4ddccb85dfbb6cc38ef150c5bf3d Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Fri, 15 May 2026 14:34:21 +0200 Subject: [PATCH] propagate action_is_pad masking through VLA-JEPA policy pipeline Pass the `action_is_pad` tensor from the batch through to the action head so padded timesteps are excluded from the flow-matching loss. --- .../policies/vla_jepa/modeling_vla_jepa.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index ecdaef978..448d34df8 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -123,6 +123,11 @@ class VLAJEPAModel(nn.Module): actions = [ex["action"] for ex in examples] if has_action else None has_state = "state" in examples[0] and examples[0]["state"] is not None state = [ex["state"] for ex in examples] if has_state else None + action_is_pad = ( + [ex["action_is_pad"] for ex in examples] + if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None + else None + ) # Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W] batch_videos = np.stack(batch_videos) @@ -243,7 +248,20 @@ class VLAJEPAModel(nn.Module): actions_rep = actions_target.repeat(num_repeated, 1, 1) state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None - action_loss = self.action_model(embodied_rep, actions_rep, state_rep) + action_is_pad_rep = None + if action_is_pad is not None: + pad_tensor = torch.stack( + [ + p.to(actions_target.device) + if isinstance(p, Tensor) + else torch.tensor(p, device=actions_target.device) + for p in action_is_pad + ] + ) # [B, T_full] + pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon] + action_is_pad_rep = pad_tensor.repeat(num_repeated, 1) # [B*R, action_horizon] + + action_loss = self.action_model(embodied_rep, actions_rep, state_rep, action_is_pad_rep) return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight} @@ -414,11 +432,15 @@ class VLAJEPAPolicy(PreTrainedPolicy): # ---- Collect actions (training only) ---- actions_list = None + action_is_pad_list = None actions_tensor = batch.get(ACTION) if actions_tensor is not None: if actions_tensor.ndim == 2: actions_tensor = actions_tensor.unsqueeze(1) actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)] + action_is_pad_tensor = batch.get("action_is_pad") + if action_is_pad_tensor is not None: + action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)] # ---- Collect state ---- state_list = None @@ -440,6 +462,8 @@ class VLAJEPAPolicy(PreTrainedPolicy): } if actions_list is not None: example["action"] = actions_list[b] + if action_is_pad_list is not None: + example["action_is_pad"] = action_is_pad_list[b] if state_list is not None: example["state"] = state_list[b] examples.append(example)