mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-29 22:27:14 +00:00
fix mem bug
This commit is contained in:
@@ -23,17 +23,26 @@ class PerTimestepNormalizer:
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.eps = eps
|
||||
self._cache = {} # Cache for device/dtype converted tensors
|
||||
|
||||
def _get_stats(self, device, dtype):
|
||||
"""Get cached stats for device/dtype, or create and cache them."""
|
||||
key = (device, dtype)
|
||||
if key not in self._cache:
|
||||
self._cache[key] = (
|
||||
self.mean.to(device, dtype),
|
||||
self.std.to(device, dtype),
|
||||
)
|
||||
return self._cache[key]
|
||||
|
||||
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
mean = self.mean.to(x.device, x.dtype)
|
||||
std = self.std.to(x.device, x.dtype)
|
||||
mean, std = self._get_stats(x.device, x.dtype)
|
||||
if x.dim() == 3 and mean.dim() == 2:
|
||||
mean, std = mean.unsqueeze(0), std.unsqueeze(0)
|
||||
return (x - mean) / (std + self.eps)
|
||||
|
||||
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
mean = self.mean.to(x.device, x.dtype)
|
||||
std = self.std.to(x.device, x.dtype)
|
||||
mean, std = self._get_stats(x.device, x.dtype)
|
||||
if x.dim() == 3 and mean.dim() == 2:
|
||||
mean, std = mean.unsqueeze(0), std.unsqueeze(0)
|
||||
return x * (std + self.eps) + mean
|
||||
|
||||
Reference in New Issue
Block a user