mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(policies): replace deprecated torch.cuda.amp.autocast with torch.amp.autocast (#3167)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -27,7 +27,7 @@ import torch.distributed as distributed
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from einops import pack, rearrange, reduce, repeat, unpack
|
||||
from torch import einsum, nn
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.amp import autocast
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from .configuration_vqbet import VQBeTConfig
|
||||
@@ -1370,7 +1370,7 @@ class EuclideanCodebook(nn.Module):
|
||||
batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
|
||||
self.replace(batch_samples, batch_mask=expired_codes)
|
||||
|
||||
@autocast(enabled=False)
|
||||
@autocast("cuda", enabled=False)
|
||||
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
|
||||
needs_codebook_dim = x.ndim < 4
|
||||
sample_codebook_temp = (
|
||||
|
||||
Reference in New Issue
Block a user