fix(policies): replace deprecated torch.cuda.amp.autocast with torch.amp.autocast (#3167)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Defalt
2026-04-19 17:25:08 +03:00
committed by GitHub
parent 3f16d98a9b
commit 5c43fa1cce
+2 -2
View File
@@ -27,7 +27,7 @@ import torch.distributed as distributed
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from einops import pack, rearrange, reduce, repeat, unpack from einops import pack, rearrange, reduce, repeat, unpack
from torch import einsum, nn from torch import einsum, nn
from torch.cuda.amp import autocast from torch.amp import autocast
from torch.optim import Optimizer from torch.optim import Optimizer
from .configuration_vqbet import VQBeTConfig from .configuration_vqbet import VQBeTConfig
@@ -1370,7 +1370,7 @@ class EuclideanCodebook(nn.Module):
batch_samples = rearrange(batch_samples, "h ... d -> h (...) d") batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
self.replace(batch_samples, batch_mask=expired_codes) self.replace(batch_samples, batch_mask=expired_codes)
@autocast(enabled=False) @autocast("cuda", enabled=False)
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4 needs_codebook_dim = x.ndim < 4
sample_codebook_temp = ( sample_codebook_temp = (