From 574081ac02b7d3da19edb220b2d322ac13677beb Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sat, 3 Jan 2026 11:34:31 +0100 Subject: [PATCH] fix mem bug --- src/lerobot/utils/relative_actions.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/lerobot/utils/relative_actions.py b/src/lerobot/utils/relative_actions.py index ee58741b9..888cb63a0 100644 --- a/src/lerobot/utils/relative_actions.py +++ b/src/lerobot/utils/relative_actions.py @@ -23,17 +23,26 @@ class PerTimestepNormalizer: self.mean = mean self.std = std self.eps = eps + self._cache = {} # Cache for device/dtype converted tensors + + def _get_stats(self, device, dtype): + """Get cached stats for device/dtype, or create and cache them.""" + key = (device, dtype) + if key not in self._cache: + self._cache[key] = ( + self.mean.to(device, dtype), + self.std.to(device, dtype), + ) + return self._cache[key] def normalize(self, x: torch.Tensor) -> torch.Tensor: - mean = self.mean.to(x.device, x.dtype) - std = self.std.to(x.device, x.dtype) + mean, std = self._get_stats(x.device, x.dtype) if x.dim() == 3 and mean.dim() == 2: mean, std = mean.unsqueeze(0), std.unsqueeze(0) return (x - mean) / (std + self.eps) def unnormalize(self, x: torch.Tensor) -> torch.Tensor: - mean = self.mean.to(x.device, x.dtype) - std = self.std.to(x.device, x.dtype) + mean, std = self._get_stats(x.device, x.dtype) if x.dim() == 3 and mean.dim() == 2: mean, std = mean.unsqueeze(0), std.unsqueeze(0) return x * (std + self.eps) + mean