fix mem bug

This commit is contained in:
Pepijn
2026-01-03 11:34:31 +01:00
parent c5f66edff9
commit 574081ac02
+13 -4
View File
@@ -23,17 +23,26 @@ class PerTimestepNormalizer:
self.mean = mean
self.std = std
self.eps = eps
self._cache = {} # Cache for device/dtype converted tensors
def _get_stats(self, device, dtype):
"""Get cached stats for device/dtype, or create and cache them."""
key = (device, dtype)
if key not in self._cache:
self._cache[key] = (
self.mean.to(device, dtype),
self.std.to(device, dtype),
)
return self._cache[key]
def normalize(self, x: torch.Tensor) -> torch.Tensor:
mean = self.mean.to(x.device, x.dtype)
std = self.std.to(x.device, x.dtype)
mean, std = self._get_stats(x.device, x.dtype)
if x.dim() == 3 and mean.dim() == 2:
mean, std = mean.unsqueeze(0), std.unsqueeze(0)
return (x - mean) / (std + self.eps)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
mean = self.mean.to(x.device, x.dtype)
std = self.std.to(x.device, x.dtype)
mean, std = self._get_stats(x.device, x.dtype)
if x.dim() == 3 and mean.dim() == 2:
mean, std = mean.unsqueeze(0), std.unsqueeze(0)
return x * (std + self.eps) + mean