Check that the required normalization stats exist before applying normalization

This commit is contained in:
Pepijn
2025-09-25 10:16:02 +02:00
parent 2740420d87
commit 3095c1e2bd
+32 -6
View File
@@ -322,7 +322,14 @@ class _NormalizationMixin:
stats = self._tensor_stats[key]
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
if norm_mode == NormalizationMode.MEAN_STD:
mean = stats.get("mean", None)
std = stats.get("std", None)
if mean is None or std is None:
raise ValueError(
"MEAN_STD normalization mode requires mean and std stats, please update the dataset with the correct stats"
)
mean, std = stats["mean"], stats["std"]
# Avoid division by zero by adding a small epsilon.
denom = std + self.eps
@@ -330,7 +337,14 @@ class _NormalizationMixin:
return tensor * std + mean
return (tensor - mean) / denom
if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats:
if norm_mode == NormalizationMode.MIN_MAX:
min_val = stats.get("min", None)
max_val = stats.get("max", None)
if min_val is None or max_val is None:
raise ValueError(
"MIN_MAX normalization mode requires min and max stats, please update the dataset with the correct stats"
)
min_val, max_val = stats["min"], stats["max"]
denom = max_val - min_val
# When min_val == max_val, substitute the denominator with a small epsilon
@@ -345,8 +359,14 @@ class _NormalizationMixin:
# Map from [min, max] to [-1, 1]
return 2 * (tensor - min_val) / denom - 1
if norm_mode == NormalizationMode.QUANTILES and "q01" in stats and "q99" in stats:
q01, q99 = stats["q01"], stats["q99"]
if norm_mode == NormalizationMode.QUANTILES:
q01 = stats.get("q01", None)
q99 = stats.get("q99", None)
if q01 is None or q99 is None:
raise ValueError(
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats"
)
denom = q99 - q01
# Avoid division by zero by adding epsilon when quantiles are identical
denom = torch.where(
@@ -356,8 +376,14 @@ class _NormalizationMixin:
return tensor * denom + q01
return (tensor - q01) / denom
if norm_mode == NormalizationMode.QUANTILE10 and "q10" in stats and "q90" in stats:
q10, q90 = stats["q10"], stats["q90"]
if norm_mode == NormalizationMode.QUANTILE10:
q10 = stats.get("q10", None)
q90 = stats.get("q90", None)
if q10 is None or q90 is None:
raise ValueError(
"QUANTILE10 normalization mode requires q10 and q90 stats, please update the dataset with the correct stats"
)
denom = q90 - q10
# Avoid division by zero by adding epsilon when quantiles are identical
denom = torch.where(