mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
feat(normalization): Implement IDENTITY mode for normalization and unnormalization
- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes.
This commit is contained in:
committed by
Steven Palma
parent
c0013b130b
commit
fbe9009db2
@@ -128,7 +128,19 @@ class NormalizerProcessor:
|
||||
|
||||
processed = dict(observation)
|
||||
for key in keys_to_norm:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
if key not in processed or key not in self.features:
|
||||
continue
|
||||
|
||||
# Check the normalization mode for this feature type
|
||||
feature = self.features[key]
|
||||
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip normalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
# Skip if no stats available for this key
|
||||
if key not in self._tensor_stats:
|
||||
continue
|
||||
|
||||
orig_val = processed[key]
|
||||
@@ -139,16 +151,32 @@ class NormalizerProcessor:
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = (tensor - mean) / (std + self.eps)
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = (tensor - mean) / (std + self.eps)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
return processed
|
||||
|
||||
def _normalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
if action is None:
|
||||
return action
|
||||
|
||||
# Check the normalization mode for actions
|
||||
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip normalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
return action
|
||||
|
||||
# Skip if no stats available for actions
|
||||
if "action" not in self._tensor_stats:
|
||||
return action
|
||||
|
||||
tensor = (
|
||||
@@ -157,13 +185,20 @@ class NormalizerProcessor:
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return (tensor - mean) / (std + self.eps)
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return (tensor - mean) / (std + self.eps)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# If we reach here, the required stats for the normalization mode are not available
|
||||
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
@@ -259,8 +294,21 @@ class UnnormalizerProcessor:
|
||||
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
||||
processed = dict(observation)
|
||||
for key in keys:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
if key not in processed or key not in self.features:
|
||||
continue
|
||||
|
||||
# Check the normalization mode for this feature type
|
||||
feature = self.features[key]
|
||||
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip unnormalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
# Skip if no stats available for this key
|
||||
if key not in self._tensor_stats:
|
||||
continue
|
||||
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
@@ -268,30 +316,55 @@ class UnnormalizerProcessor:
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = tensor * std + mean
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = tensor * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
return processed
|
||||
|
||||
def _unnormalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
if action is None:
|
||||
return action
|
||||
|
||||
# Check the normalization mode for actions
|
||||
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip unnormalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
return action
|
||||
|
||||
# Skip if no stats available for actions
|
||||
if "action" not in self._tensor_stats:
|
||||
return action
|
||||
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return tensor * std + mean
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return tensor * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# If we reach here, the required stats for the normalization mode are not available
|
||||
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
|
||||
Reference in New Issue
Block a user