mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
clamp quantiles
This commit is contained in:
@@ -375,8 +375,11 @@ class _NormalizationMixin:
|
|||||||
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
||||||
)
|
)
|
||||||
if inverse:
|
if inverse:
|
||||||
|
tensor = torch.clamp(tensor, -1.0, 1.0)
|
||||||
return tensor * denom + q01
|
return tensor * denom + q01
|
||||||
return 2.0 * (tensor - q01) / denom - 1.0
|
result = 2.0 * (tensor - q01) / denom - 1.0
|
||||||
|
result = torch.clamp(result, -1.0, 1.0)
|
||||||
|
return result
|
||||||
|
|
||||||
if norm_mode == NormalizationMode.QUANTILE10:
|
if norm_mode == NormalizationMode.QUANTILE10:
|
||||||
q10 = stats.get("q10", None)
|
q10 = stats.get("q10", None)
|
||||||
@@ -392,8 +395,11 @@ class _NormalizationMixin:
|
|||||||
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
||||||
)
|
)
|
||||||
if inverse:
|
if inverse:
|
||||||
|
tensor = torch.clamp(tensor, -1.0, 1.0)
|
||||||
return tensor * denom + q10
|
return tensor * denom + q10
|
||||||
return 2.0 * (tensor - q10) / denom - 1.0
|
result = 2.0 * (tensor - q10) / denom - 1.0
|
||||||
|
result = torch.clamp(result, -1.0, 1.0)
|
||||||
|
return result
|
||||||
|
|
||||||
# If necessary stats are missing, return input unchanged.
|
# If necessary stats are missing, return input unchanged.
|
||||||
return tensor
|
return tensor
|
||||||
|
|||||||
Reference in New Issue
Block a user