diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index e81fb4723..a2e981c4c 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -174,9 +174,7 @@ class VLAJEPAModel(nn.Module): embodied_idx = (input_ids == self.embodied_action_token_id).nonzero(as_tuple=True) action_idx = None if need_action_tokens: - action_mask = torch.isin( - input_ids, torch.tensor(self.action_token_ids, device=input_ids.device) - ) + action_mask = torch.isin(input_ids, torch.tensor(self.action_token_ids, device=input_ids.device)) action_idx = action_mask.nonzero(as_tuple=True) device_type = next(self.parameters()).device.type