fix(policies): replace deprecated torch.cuda.amp.autocast with torch.amp.autocast (#3167)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Defalt
2026-04-19 17:25:08 +03:00
committed by GitHub
parent 3f16d98a9b
commit 5c43fa1cce
+2 -2
View File
@@ -27,7 +27,7 @@ import torch.distributed as distributed
import torch.nn.functional as F # noqa: N812
from einops import pack, rearrange, reduce, repeat, unpack
from torch import einsum, nn
from torch.cuda.amp import autocast
from torch.amp import autocast
from torch.optim import Optimizer
from .configuration_vqbet import VQBeTConfig
@@ -1370,7 +1370,7 @@ class EuclideanCodebook(nn.Module):
batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
self.replace(batch_samples, batch_mask=expired_codes)
@autocast(enabled=False)
@autocast("cuda", enabled=False)
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = (