mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
fix(policies): replace deprecated torch.cuda.amp.autocast with torch.amp.autocast (#3167)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -27,7 +27,7 @@ import torch.distributed as distributed
|
|||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from einops import pack, rearrange, reduce, repeat, unpack
|
from einops import pack, rearrange, reduce, repeat, unpack
|
||||||
from torch import einsum, nn
|
from torch import einsum, nn
|
||||||
from torch.cuda.amp import autocast
|
from torch.amp import autocast
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
from .configuration_vqbet import VQBeTConfig
|
from .configuration_vqbet import VQBeTConfig
|
||||||
@@ -1370,7 +1370,7 @@ class EuclideanCodebook(nn.Module):
|
|||||||
batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
|
batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
|
||||||
self.replace(batch_samples, batch_mask=expired_codes)
|
self.replace(batch_samples, batch_mask=expired_codes)
|
||||||
|
|
||||||
@autocast(enabled=False)
|
@autocast("cuda", enabled=False)
|
||||||
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
|
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
|
||||||
needs_codebook_dim = x.ndim < 4
|
needs_codebook_dim = x.ndim < 4
|
||||||
sample_codebook_temp = (
|
sample_codebook_temp = (
|
||||||
|
|||||||
Reference in New Issue
Block a user