This commit is contained in:
Maximellerbach
2026-06-09 17:12:50 +02:00
parent 31ddb8f493
commit c1332ac37e
@@ -174,9 +174,7 @@ class VLAJEPAModel(nn.Module):
embodied_idx = (input_ids == self.embodied_action_token_id).nonzero(as_tuple=True)
action_idx = None
if need_action_tokens:
action_mask = torch.isin(
input_ids, torch.tensor(self.action_token_ids, device=input_ids.device)
)
action_mask = torch.isin(input_ids, torch.tensor(self.action_token_ids, device=input_ids.device))
action_idx = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type