mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
check normalization before applying
This commit is contained in:
@@ -322,7 +322,14 @@ class _NormalizationMixin:
|
||||
|
||||
stats = self._tensor_stats[key]
|
||||
|
||||
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
|
||||
if norm_mode == NormalizationMode.MEAN_STD:
|
||||
mean = stats.get("mean", None)
|
||||
std = stats.get("std", None)
|
||||
if mean is None or std is None:
|
||||
raise ValueError(
|
||||
"MEAN_STD normalization mode requires mean and std stats, please update the dataset with the correct stats"
|
||||
)
|
||||
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
# Avoid division by zero by adding a small epsilon.
|
||||
denom = std + self.eps
|
||||
@@ -330,7 +337,14 @@ class _NormalizationMixin:
|
||||
return tensor * std + mean
|
||||
return (tensor - mean) / denom
|
||||
|
||||
if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats:
|
||||
if norm_mode == NormalizationMode.MIN_MAX:
|
||||
min_val = stats.get("min", None)
|
||||
max_val = stats.get("max", None)
|
||||
if min_val is None or max_val is None:
|
||||
raise ValueError(
|
||||
"MIN_MAX normalization mode requires min and max stats, please update the dataset with the correct stats"
|
||||
)
|
||||
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
denom = max_val - min_val
|
||||
# When min_val == max_val, substitute the denominator with a small epsilon
|
||||
@@ -345,8 +359,14 @@ class _NormalizationMixin:
|
||||
# Map from [min, max] to [-1, 1]
|
||||
return 2 * (tensor - min_val) / denom - 1
|
||||
|
||||
if norm_mode == NormalizationMode.QUANTILES and "q01" in stats and "q99" in stats:
|
||||
q01, q99 = stats["q01"], stats["q99"]
|
||||
if norm_mode == NormalizationMode.QUANTILES:
|
||||
q01 = stats.get("q01", None)
|
||||
q99 = stats.get("q99", None)
|
||||
if q01 is None or q99 is None:
|
||||
raise ValueError(
|
||||
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats"
|
||||
)
|
||||
|
||||
denom = q99 - q01
|
||||
# Avoid division by zero by adding epsilon when quantiles are identical
|
||||
denom = torch.where(
|
||||
@@ -356,8 +376,14 @@ class _NormalizationMixin:
|
||||
return tensor * denom + q01
|
||||
return (tensor - q01) / denom
|
||||
|
||||
if norm_mode == NormalizationMode.QUANTILE10 and "q10" in stats and "q90" in stats:
|
||||
q10, q90 = stats["q10"], stats["q90"]
|
||||
if norm_mode == NormalizationMode.QUANTILE10:
|
||||
q10 = stats.get("q10", None)
|
||||
q90 = stats.get("q90", None)
|
||||
if q10 is None or q90 is None:
|
||||
raise ValueError(
|
||||
"QUANTILE10 normalization mode requires q10 and q90 stats, please update the dataset with the correct stats"
|
||||
)
|
||||
|
||||
denom = q90 - q10
|
||||
# Avoid division by zero by adding epsilon when quantiles are identical
|
||||
denom = torch.where(
|
||||
|
||||
Reference in New Issue
Block a user