mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
Refactor normalization components and update tests
- Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes.
This commit is contained in:
@@ -13,7 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from .normalize_processor import NormalizationProcessor
|
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||||
from .observation_processor import (
|
from .observation_processor import (
|
||||||
ImageProcessor,
|
ImageProcessor,
|
||||||
StateProcessor,
|
StateProcessor,
|
||||||
@@ -38,7 +38,8 @@ __all__ = [
|
|||||||
"EnvTransition",
|
"EnvTransition",
|
||||||
"ImageProcessor",
|
"ImageProcessor",
|
||||||
"InfoProcessor",
|
"InfoProcessor",
|
||||||
"NormalizationProcessor",
|
"NormalizerProcessor",
|
||||||
|
"UnnormalizerProcessor",
|
||||||
"ObservationProcessor",
|
"ObservationProcessor",
|
||||||
"ProcessorStep",
|
"ProcessorStep",
|
||||||
"RenameProcessor",
|
"RenameProcessor",
|
||||||
|
|||||||
@@ -30,31 +30,22 @@ def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dic
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="observation_normalizer")
|
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||||
class ObservationNormalizer:
|
class NormalizerProcessor:
|
||||||
"""Normalize observations using dataset statistics.
|
"""Normalize observations *and* actions in one go.
|
||||||
|
|
||||||
This processor normalizes selected observation keys using either:
|
This is a thin convenience wrapper equivalent to::
|
||||||
- Standard normalization: ``(x - mean) / (std + eps)``
|
|
||||||
- Min-Max normalization to [-1, 1]: ``2 * (x - min) / (max - min + eps) - 1``
|
|
||||||
|
|
||||||
Parameters
|
proc = RobotProcessor([ObservationNormalizer(stats, ...), ActionNormalizer(action_stats, ...)])
|
||||||
----------
|
|
||||||
stats : Dict[str, Dict[str, np.ndarray | Tensor]]
|
Keeping it as a single step is handy for profiling and simplifies
|
||||||
Dataset statistics. Each entry must provide either
|
configuration files.
|
||||||
``{"mean", "std"}`` or ``{"min", "max"}``.
|
|
||||||
normalize_keys : set[str] | None, default=None
|
|
||||||
Observation keys to normalize. ``None`` means all keys
|
|
||||||
present in both the observation and stats.
|
|
||||||
eps : float, default=1e-8
|
|
||||||
Small constant to avoid division by zero.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
stats: dict[str, dict[str, Any]]
|
stats: dict[str, dict[str, Any]]
|
||||||
normalize_keys: set[str] | None = None
|
normalize_keys: set[str] | None = None
|
||||||
eps: float = 1e-8
|
eps: float = 1e-8
|
||||||
|
|
||||||
# Cached tensors for performance
|
|
||||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -64,70 +55,66 @@ class ObservationNormalizer:
|
|||||||
*,
|
*,
|
||||||
normalize_keys: set[str] | None = None,
|
normalize_keys: set[str] | None = None,
|
||||||
eps: float = 1e-8,
|
eps: float = 1e-8,
|
||||||
) -> ObservationNormalizer:
|
) -> NormalizerProcessor:
|
||||||
"""Create from a LeRobotDataset."""
|
return cls(stats=dataset.meta.stats, normalize_keys=normalize_keys, eps=eps)
|
||||||
# Filter stats to only include observation keys
|
|
||||||
obs_stats = {k: v for k, v in dataset.meta.stats.items() if k != "action"}
|
|
||||||
return cls(stats=obs_stats, normalize_keys=normalize_keys, eps=eps)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def _normalize_obs(self, observation):
|
||||||
observation = transition[TransitionIndex.OBSERVATION]
|
|
||||||
|
|
||||||
if observation is None:
|
if observation is None:
|
||||||
return transition
|
return None
|
||||||
|
|
||||||
# Determine which keys to normalize
|
|
||||||
keys_to_norm = (
|
keys_to_norm = (
|
||||||
self.normalize_keys if self.normalize_keys is not None else set(self._tensor_stats.keys())
|
self.normalize_keys
|
||||||
|
if self.normalize_keys is not None
|
||||||
|
else {k for k in self._tensor_stats if k != "action"}
|
||||||
)
|
)
|
||||||
|
processed = dict(observation)
|
||||||
# Create a copy to avoid mutating input
|
|
||||||
processed_obs = dict(observation)
|
|
||||||
|
|
||||||
for key in keys_to_norm:
|
for key in keys_to_norm:
|
||||||
if key not in processed_obs:
|
if key not in processed or key not in self._tensor_stats:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if key not in self._tensor_stats:
|
orig_val = processed[key]
|
||||||
if self.normalize_keys is not None:
|
tensor = (
|
||||||
# User explicitly requested this key but stats are missing
|
orig_val.to(dtype=torch.float32)
|
||||||
raise KeyError(f"Stats not found for requested key '{key}'")
|
if isinstance(orig_val, torch.Tensor)
|
||||||
continue
|
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||||
|
)
|
||||||
|
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||||
|
|
||||||
# Convert to tensor if needed
|
|
||||||
orig_val = processed_obs[key]
|
|
||||||
if isinstance(orig_val, torch.Tensor):
|
|
||||||
tensor = orig_val.to(dtype=torch.float32)
|
|
||||||
elif isinstance(orig_val, np.ndarray):
|
|
||||||
tensor = torch.from_numpy(orig_val.astype(np.float32))
|
|
||||||
else:
|
|
||||||
# For lists, tuples, scalars, etc.
|
|
||||||
tensor = torch.as_tensor(orig_val, dtype=torch.float32)
|
|
||||||
|
|
||||||
stats = self._tensor_stats[key]
|
|
||||||
# Move stats to same device as data
|
|
||||||
stats = {k: v.to(device=tensor.device) for k, v in stats.items()}
|
|
||||||
|
|
||||||
# Apply normalization
|
|
||||||
if "mean" in stats and "std" in stats:
|
if "mean" in stats and "std" in stats:
|
||||||
mean, std = stats["mean"], stats["std"]
|
mean, std = stats["mean"], stats["std"]
|
||||||
processed_obs[key] = (tensor - mean) / (std + self.eps)
|
processed[key] = (tensor - mean) / (std + self.eps)
|
||||||
elif "min" in stats and "max" in stats:
|
elif "min" in stats and "max" in stats:
|
||||||
min_val, max_val = stats["min"], stats["max"]
|
min_val, max_val = stats["min"], stats["max"]
|
||||||
# Normalize to [0, 1] then to [-1, 1]
|
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||||
processed_obs[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
return processed
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Stats for key '{key}' must contain either ('mean', 'std') or ('min', 'max')"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return new transition with normalized observation
|
def _normalize_action(self, action):
|
||||||
|
if action is None or "action" not in self._tensor_stats:
|
||||||
|
return action
|
||||||
|
|
||||||
|
tensor = (
|
||||||
|
action.to(dtype=torch.float32)
|
||||||
|
if isinstance(action, torch.Tensor)
|
||||||
|
else torch.as_tensor(action, dtype=torch.float32)
|
||||||
|
)
|
||||||
|
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||||
|
if "mean" in stats and "std" in stats:
|
||||||
|
mean, std = stats["mean"], stats["std"]
|
||||||
|
return (tensor - mean) / (std + self.eps)
|
||||||
|
if "min" in stats and "max" in stats:
|
||||||
|
min_val, max_val = stats["min"], stats["max"]
|
||||||
|
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||||
|
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
observation = self._normalize_obs(transition[TransitionIndex.OBSERVATION])
|
||||||
|
action = self._normalize_action(transition[TransitionIndex.ACTION])
|
||||||
return (
|
return (
|
||||||
processed_obs,
|
observation,
|
||||||
transition[TransitionIndex.ACTION],
|
action,
|
||||||
transition[TransitionIndex.REWARD],
|
transition[TransitionIndex.REWARD],
|
||||||
transition[TransitionIndex.DONE],
|
transition[TransitionIndex.DONE],
|
||||||
transition[TransitionIndex.TRUNCATED],
|
transition[TransitionIndex.TRUNCATED],
|
||||||
@@ -136,149 +123,38 @@ class ObservationNormalizer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
return {
|
return {"normalize_keys": list(self.normalize_keys) if self.normalize_keys else None, "eps": self.eps}
|
||||||
"normalize_keys": list(self.normalize_keys) if self.normalize_keys is not None else None,
|
|
||||||
"eps": self.eps,
|
|
||||||
}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, Tensor]:
|
def state_dict(self) -> dict[str, Tensor]:
|
||||||
flat_state: dict[str, Tensor] = {}
|
flat = {}
|
||||||
for key, sub in self._tensor_stats.items():
|
for key, sub in self._tensor_stats.items():
|
||||||
for stat_name, tensor in sub.items():
|
for stat_name, tensor in sub.items():
|
||||||
flat_state[f"{key}.{stat_name}"] = tensor
|
flat[f"{key}.{stat_name}"] = tensor
|
||||||
return flat_state
|
return flat
|
||||||
|
|
||||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||||
self._tensor_stats.clear()
|
self._tensor_stats.clear()
|
||||||
for flat_key, tensor in state.items():
|
for flat_key, tensor in state.items():
|
||||||
key, stat_name = flat_key.rsplit(".", 1)
|
key, stat_name = flat_key.rsplit(".", 1)
|
||||||
if key not in self._tensor_stats:
|
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||||
self._tensor_stats[key] = {}
|
|
||||||
self._tensor_stats[key][stat_name] = tensor
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self):
|
||||||
"""Nothing to reset for this stateless processor."""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="action_unnormalizer")
|
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||||
class ActionUnnormalizer:
|
class UnnormalizerProcessor:
|
||||||
"""Un-normalize actions using dataset statistics.
|
"""Inverse normalisation for observations and actions.
|
||||||
|
|
||||||
This processor un-normalizes actions using the inverse of normalization:
|
Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
|
||||||
- Standard: ``action * std + mean``
|
transform.
|
||||||
- Min-Max from [-1, 1]: ``(action + 1) / 2 * (max - min) + min``
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
action_stats : Dict[str, np.ndarray | Tensor]
|
|
||||||
Action statistics containing either ``{"mean", "std"}`` or ``{"min", "max"}``.
|
|
||||||
eps : float, default=1e-8
|
|
||||||
Small constant used during normalization (not used in unnormalization).
|
|
||||||
"""
|
|
||||||
|
|
||||||
action_stats: dict[str, Any]
|
|
||||||
eps: float = 1e-8 # Kept for consistency, not used in unnormalization
|
|
||||||
|
|
||||||
# Cached tensors for performance
|
|
||||||
_tensor_stats: dict[str, Tensor] = field(default_factory=dict, init=False, repr=False)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_lerobot_dataset(
|
|
||||||
cls,
|
|
||||||
dataset: LeRobotDataset,
|
|
||||||
*,
|
|
||||||
eps: float = 1e-8,
|
|
||||||
) -> ActionUnnormalizer:
|
|
||||||
"""Create from a LeRobotDataset."""
|
|
||||||
if "action" not in dataset.meta.stats:
|
|
||||||
raise ValueError("Dataset does not contain action statistics")
|
|
||||||
return cls(action_stats=dataset.meta.stats["action"], eps=eps)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# Convert action stats to tensors
|
|
||||||
tensor_stats = _convert_stats_to_tensors({"action": self.action_stats})
|
|
||||||
self._tensor_stats = tensor_stats["action"]
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
action = transition[TransitionIndex.ACTION]
|
|
||||||
|
|
||||||
if action is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
# Convert to tensor if needed
|
|
||||||
if isinstance(action, torch.Tensor):
|
|
||||||
action = action.to(dtype=torch.float32)
|
|
||||||
else:
|
|
||||||
action = torch.as_tensor(action, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Move stats to same device as action
|
|
||||||
stats = {k: v.to(device=action.device) for k, v in self._tensor_stats.items()}
|
|
||||||
|
|
||||||
# Apply unnormalization
|
|
||||||
if "mean" in stats and "std" in stats:
|
|
||||||
mean, std = stats["mean"], stats["std"]
|
|
||||||
unnormalized_action = action * std + mean
|
|
||||||
elif "min" in stats and "max" in stats:
|
|
||||||
min_val, max_val = stats["min"], stats["max"]
|
|
||||||
# Map from [-1, 1] to [0, 1] then to [min, max]
|
|
||||||
unnormalized_action = (action + 1) / 2 * (max_val - min_val) + min_val
|
|
||||||
else:
|
|
||||||
raise ValueError("Action stats must contain either ('mean', 'std') or ('min', 'max')")
|
|
||||||
|
|
||||||
# Return new transition with unnormalized action
|
|
||||||
return (
|
|
||||||
transition[TransitionIndex.OBSERVATION],
|
|
||||||
unnormalized_action,
|
|
||||||
transition[TransitionIndex.REWARD],
|
|
||||||
transition[TransitionIndex.DONE],
|
|
||||||
transition[TransitionIndex.TRUNCATED],
|
|
||||||
transition[TransitionIndex.INFO],
|
|
||||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
return {"eps": self.eps}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, Tensor]:
|
|
||||||
return dict(self._tensor_stats.items())
|
|
||||||
|
|
||||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
|
||||||
self._tensor_stats = dict(state)
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Nothing to reset for this stateless processor."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
@ProcessorStepRegistry.register(name="normalization_processor")
|
|
||||||
class NormalizationProcessor:
|
|
||||||
"""Combined processor that normalizes observations and/or un-normalizes actions.
|
|
||||||
|
|
||||||
This processor combines the functionality of ObservationNormalizer and
|
|
||||||
ActionUnnormalizer for convenience when both operations are needed.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
stats : Dict[str, Dict[str, np.ndarray | Tensor]]
|
|
||||||
Dataset statistics as returned by ``LeRobotDataset.meta.stats``.
|
|
||||||
normalize_keys : set[str] | None, default=None
|
|
||||||
Observation keys to normalize. ``None`` means all keys
|
|
||||||
present in both the observation and stats.
|
|
||||||
unnormalize_action : bool, default=True
|
|
||||||
Whether to un-normalize actions.
|
|
||||||
eps : float, default=1e-8
|
|
||||||
Small constant to avoid division by zero.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
stats: dict[str, dict[str, Any]]
|
stats: dict[str, dict[str, Any]]
|
||||||
normalize_keys: set[str] | None = None
|
unnormalize_keys: set[str] | None = None
|
||||||
unnormalize_action: bool = True
|
|
||||||
eps: float = 1e-8
|
eps: float = 1e-8
|
||||||
|
|
||||||
# Cached tensors for performance
|
|
||||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -286,75 +162,61 @@ class NormalizationProcessor:
|
|||||||
cls,
|
cls,
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
*,
|
*,
|
||||||
normalize_keys: set[str] | None = None,
|
unnormalize_keys: set[str] | None = None,
|
||||||
unnormalize_action: bool = True,
|
|
||||||
eps: float = 1e-8,
|
eps: float = 1e-8,
|
||||||
) -> NormalizationProcessor:
|
) -> UnnormalizerProcessor:
|
||||||
"""Create from a LeRobotDataset."""
|
return cls(stats=dataset.meta.stats, unnormalize_keys=unnormalize_keys, eps=eps)
|
||||||
return cls(
|
|
||||||
stats=dataset.meta.stats,
|
|
||||||
normalize_keys=normalize_keys,
|
|
||||||
unnormalize_action=unnormalize_action,
|
|
||||||
eps=eps,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def _unnormalize_obs(self, observation):
|
||||||
observation = transition[TransitionIndex.OBSERVATION]
|
if observation is None:
|
||||||
action = transition[TransitionIndex.ACTION]
|
return None
|
||||||
|
keys = (
|
||||||
# Normalize observations
|
self.unnormalize_keys
|
||||||
if observation is not None:
|
if self.unnormalize_keys is not None
|
||||||
processed_obs = dict(observation)
|
else {k for k in self._tensor_stats if k != "action"}
|
||||||
keys_to_norm = (
|
)
|
||||||
self.normalize_keys
|
processed = dict(observation)
|
||||||
if self.normalize_keys is not None
|
for key in keys:
|
||||||
else {k for k in self._tensor_stats if k != "action"}
|
if key not in processed or key not in self._tensor_stats:
|
||||||
|
continue
|
||||||
|
orig_val = processed[key]
|
||||||
|
tensor = (
|
||||||
|
orig_val.to(dtype=torch.float32)
|
||||||
|
if isinstance(orig_val, torch.Tensor)
|
||||||
|
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||||
)
|
)
|
||||||
|
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||||
for key in keys_to_norm:
|
|
||||||
if key not in processed_obs or key not in self._tensor_stats:
|
|
||||||
continue
|
|
||||||
|
|
||||||
orig_val = processed_obs[key]
|
|
||||||
if isinstance(orig_val, torch.Tensor):
|
|
||||||
tensor = orig_val.to(dtype=torch.float32)
|
|
||||||
elif isinstance(orig_val, np.ndarray):
|
|
||||||
tensor = torch.from_numpy(orig_val.astype(np.float32))
|
|
||||||
else:
|
|
||||||
tensor = torch.as_tensor(orig_val, dtype=torch.float32)
|
|
||||||
|
|
||||||
stats = self._tensor_stats[key]
|
|
||||||
stats = {k: v.to(device=tensor.device) for k, v in stats.items()}
|
|
||||||
|
|
||||||
if "mean" in stats and "std" in stats:
|
|
||||||
mean, std = stats["mean"], stats["std"]
|
|
||||||
processed_obs[key] = (tensor - mean) / (std + self.eps)
|
|
||||||
elif "min" in stats and "max" in stats:
|
|
||||||
min_val, max_val = stats["min"], stats["max"]
|
|
||||||
processed_obs[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
|
||||||
|
|
||||||
observation = processed_obs
|
|
||||||
|
|
||||||
# Un-normalize action
|
|
||||||
if self.unnormalize_action and action is not None and "action" in self._tensor_stats:
|
|
||||||
if isinstance(action, torch.Tensor):
|
|
||||||
action = action.to(dtype=torch.float32)
|
|
||||||
else:
|
|
||||||
action = torch.as_tensor(action, dtype=torch.float32)
|
|
||||||
|
|
||||||
stats = {k: v.to(device=action.device) for k, v in self._tensor_stats["action"].items()}
|
|
||||||
|
|
||||||
if "mean" in stats and "std" in stats:
|
if "mean" in stats and "std" in stats:
|
||||||
mean, std = stats["mean"], stats["std"]
|
mean, std = stats["mean"], stats["std"]
|
||||||
action = action * std + mean
|
processed[key] = tensor * std + mean
|
||||||
elif "min" in stats and "max" in stats:
|
elif "min" in stats and "max" in stats:
|
||||||
min_val, max_val = stats["min"], stats["max"]
|
min_val, max_val = stats["min"], stats["max"]
|
||||||
action = (action + 1) / 2 * (max_val - min_val) + min_val
|
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||||
|
return processed
|
||||||
|
|
||||||
# Return new transition
|
def _unnormalize_action(self, action):
|
||||||
|
if action is None or "action" not in self._tensor_stats:
|
||||||
|
return action
|
||||||
|
tensor = (
|
||||||
|
action.to(dtype=torch.float32)
|
||||||
|
if isinstance(action, torch.Tensor)
|
||||||
|
else torch.as_tensor(action, dtype=torch.float32)
|
||||||
|
)
|
||||||
|
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||||
|
if "mean" in stats and "std" in stats:
|
||||||
|
mean, std = stats["mean"], stats["std"]
|
||||||
|
return tensor * std + mean
|
||||||
|
if "min" in stats and "max" in stats:
|
||||||
|
min_val, max_val = stats["min"], stats["max"]
|
||||||
|
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||||
|
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
observation = self._unnormalize_obs(transition[TransitionIndex.OBSERVATION])
|
||||||
|
action = self._unnormalize_action(transition[TransitionIndex.ACTION])
|
||||||
return (
|
return (
|
||||||
observation,
|
observation,
|
||||||
action,
|
action,
|
||||||
@@ -367,26 +229,22 @@ class NormalizationProcessor:
|
|||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"normalize_keys": list(self.normalize_keys) if self.normalize_keys is not None else None,
|
"unnormalize_keys": list(self.unnormalize_keys) if self.unnormalize_keys else None,
|
||||||
"unnormalize_action": self.unnormalize_action,
|
|
||||||
"eps": self.eps,
|
"eps": self.eps,
|
||||||
}
|
}
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, Tensor]:
|
def state_dict(self) -> dict[str, Tensor]:
|
||||||
flat_state: dict[str, Tensor] = {}
|
flat = {}
|
||||||
for key, sub in self._tensor_stats.items():
|
for key, sub in self._tensor_stats.items():
|
||||||
for stat_name, tensor in sub.items():
|
for stat_name, tensor in sub.items():
|
||||||
flat_state[f"{key}.{stat_name}"] = tensor
|
flat[f"{key}.{stat_name}"] = tensor
|
||||||
return flat_state
|
return flat
|
||||||
|
|
||||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||||
self._tensor_stats.clear()
|
self._tensor_stats.clear()
|
||||||
for flat_key, tensor in state.items():
|
for flat_key, tensor in state.items():
|
||||||
key, stat_name = flat_key.rsplit(".", 1)
|
key, stat_name = flat_key.rsplit(".", 1)
|
||||||
if key not in self._tensor_stats:
|
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||||
self._tensor_stats[key] = {}
|
|
||||||
self._tensor_stats[key][stat_name] = tensor
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self):
|
||||||
"""Nothing to reset for this stateless processor."""
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -5,9 +5,8 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.processor.normalize_processor import (
|
from lerobot.processor.normalize_processor import (
|
||||||
ActionUnnormalizer,
|
NormalizerProcessor,
|
||||||
NormalizationProcessor,
|
UnnormalizerProcessor,
|
||||||
ObservationNormalizer,
|
|
||||||
_convert_stats_to_tensors,
|
_convert_stats_to_tensors,
|
||||||
)
|
)
|
||||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||||
@@ -77,7 +76,7 @@ def test_unsupported_type():
|
|||||||
_convert_stats_to_tensors(stats)
|
_convert_stats_to_tensors(stats)
|
||||||
|
|
||||||
|
|
||||||
# Fixtures for ObservationNormalizer tests
|
# Fixtures for observation normalisation tests using NormalizerProcessor
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def observation_stats():
|
def observation_stats():
|
||||||
return {
|
return {
|
||||||
@@ -94,7 +93,8 @@ def observation_stats():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def observation_normalizer(observation_stats):
|
def observation_normalizer(observation_stats):
|
||||||
return ObservationNormalizer(stats=observation_stats)
|
"""Return a NormalizerProcessor that only has observation stats (no action)."""
|
||||||
|
return NormalizerProcessor(stats=observation_stats)
|
||||||
|
|
||||||
|
|
||||||
def test_mean_std_normalization(observation_normalizer):
|
def test_mean_std_normalization(observation_normalizer):
|
||||||
@@ -129,7 +129,7 @@ def test_min_max_normalization(observation_normalizer):
|
|||||||
|
|
||||||
|
|
||||||
def test_selective_normalization(observation_stats):
|
def test_selective_normalization(observation_stats):
|
||||||
normalizer = ObservationNormalizer(stats=observation_stats, normalize_keys={"observation.image"})
|
normalizer = NormalizerProcessor(stats=observation_stats, normalize_keys={"observation.image"})
|
||||||
|
|
||||||
observation = {
|
observation = {
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
@@ -146,46 +146,9 @@ def test_selective_normalization(observation_stats):
|
|||||||
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
|
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
|
||||||
|
|
||||||
|
|
||||||
def test_missing_stats_error(observation_stats):
|
|
||||||
normalizer = ObservationNormalizer(
|
|
||||||
stats={"observation.image": observation_stats["observation.image"]},
|
|
||||||
normalize_keys={"observation.image", "observation.missing"},
|
|
||||||
)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"observation.image": torch.tensor([0.5, 0.5, 0.5]),
|
|
||||||
"observation.missing": torch.tensor([1.0, 2.0]),
|
|
||||||
}
|
|
||||||
transition = (observation, None, None, None, None, None, None)
|
|
||||||
|
|
||||||
with pytest.raises(KeyError, match="Stats not found for requested key 'observation.missing'"):
|
|
||||||
normalizer(transition)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"input_type,input_value,expected_type",
|
|
||||||
[
|
|
||||||
("numpy", np.array([0.7, 0.5, 0.3], dtype=np.float32), torch.Tensor),
|
|
||||||
("torch", torch.tensor([0.7, 0.5, 0.3]), torch.Tensor),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_input_types(observation_normalizer, input_type, input_value, expected_type):
|
|
||||||
observation = {
|
|
||||||
"observation.image": input_value,
|
|
||||||
}
|
|
||||||
transition = (observation, None, None, None, None, None, None)
|
|
||||||
|
|
||||||
normalized_transition = observation_normalizer(transition)
|
|
||||||
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
|
||||||
|
|
||||||
expected = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
|
||||||
assert isinstance(normalized_obs["observation.image"], expected_type)
|
|
||||||
assert torch.allclose(normalized_obs["observation.image"], expected)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||||
def test_device_compatibility(observation_stats):
|
def test_device_compatibility(observation_stats):
|
||||||
normalizer = ObservationNormalizer(stats=observation_stats)
|
normalizer = NormalizerProcessor(stats=observation_stats)
|
||||||
observation = {
|
observation = {
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
||||||
}
|
}
|
||||||
@@ -205,11 +168,11 @@ def test_from_lerobot_dataset():
|
|||||||
"action": {"mean": [0.0], "std": [1.0]}, # Should be filtered out
|
"action": {"mean": [0.0], "std": [1.0]}, # Should be filtered out
|
||||||
}
|
}
|
||||||
|
|
||||||
normalizer = ObservationNormalizer.from_lerobot_dataset(mock_dataset)
|
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset)
|
||||||
|
|
||||||
# Check that action stats are filtered out
|
# Both observation and action statistics should be present in tensor stats
|
||||||
assert "observation.image" in normalizer._tensor_stats
|
assert "observation.image" in normalizer._tensor_stats
|
||||||
assert "action" not in normalizer._tensor_stats
|
assert "action" in normalizer._tensor_stats
|
||||||
|
|
||||||
|
|
||||||
def test_state_dict_save_load(observation_normalizer):
|
def test_state_dict_save_load(observation_normalizer):
|
||||||
@@ -217,7 +180,7 @@ def test_state_dict_save_load(observation_normalizer):
|
|||||||
state_dict = observation_normalizer.state_dict()
|
state_dict = observation_normalizer.state_dict()
|
||||||
|
|
||||||
# Create new normalizer and load state
|
# Create new normalizer and load state
|
||||||
new_normalizer = ObservationNormalizer(stats={})
|
new_normalizer = NormalizerProcessor(stats={})
|
||||||
new_normalizer.load_state_dict(state_dict)
|
new_normalizer.load_state_dict(state_dict)
|
||||||
|
|
||||||
# Test that it works the same
|
# Test that it works the same
|
||||||
@@ -248,7 +211,7 @@ def action_stats_min_max():
|
|||||||
|
|
||||||
|
|
||||||
def test_mean_std_unnormalization(action_stats_mean_std):
|
def test_mean_std_unnormalization(action_stats_mean_std):
|
||||||
unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std)
|
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std})
|
||||||
|
|
||||||
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
||||||
transition = (None, normalized_action, None, None, None, None, None)
|
transition = (None, normalized_action, None, None, None, None, None)
|
||||||
@@ -262,7 +225,7 @@ def test_mean_std_unnormalization(action_stats_mean_std):
|
|||||||
|
|
||||||
|
|
||||||
def test_min_max_unnormalization(action_stats_min_max):
|
def test_min_max_unnormalization(action_stats_min_max):
|
||||||
unnormalizer = ActionUnnormalizer(action_stats=action_stats_min_max)
|
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_min_max})
|
||||||
|
|
||||||
# Actions in [-1, 1]
|
# Actions in [-1, 1]
|
||||||
normalized_action = torch.tensor([0.0, -1.0, 1.0])
|
normalized_action = torch.tensor([0.0, -1.0, 1.0])
|
||||||
@@ -284,7 +247,7 @@ def test_min_max_unnormalization(action_stats_min_max):
|
|||||||
|
|
||||||
|
|
||||||
def test_numpy_action_input(action_stats_mean_std):
|
def test_numpy_action_input(action_stats_mean_std):
|
||||||
unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std)
|
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std})
|
||||||
|
|
||||||
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
||||||
transition = (None, normalized_action, None, None, None, None, None)
|
transition = (None, normalized_action, None, None, None, None, None)
|
||||||
@@ -298,7 +261,7 @@ def test_numpy_action_input(action_stats_mean_std):
|
|||||||
|
|
||||||
|
|
||||||
def test_none_action(action_stats_mean_std):
|
def test_none_action(action_stats_mean_std):
|
||||||
unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std)
|
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std})
|
||||||
|
|
||||||
transition = (None, None, None, None, None, None, None)
|
transition = (None, None, None, None, None, None, None)
|
||||||
result = unnormalizer(transition)
|
result = unnormalizer(transition)
|
||||||
@@ -308,40 +271,13 @@ def test_none_action(action_stats_mean_std):
|
|||||||
|
|
||||||
|
|
||||||
def test_action_from_lerobot_dataset():
|
def test_action_from_lerobot_dataset():
|
||||||
# Mock dataset
|
|
||||||
mock_dataset = Mock()
|
mock_dataset = Mock()
|
||||||
mock_dataset.meta.stats = {
|
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
||||||
"action": {"mean": [0.0], "std": [1.0]},
|
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset)
|
||||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
assert "mean" in unnormalizer._tensor_stats["action"]
|
||||||
}
|
|
||||||
|
|
||||||
unnormalizer = ActionUnnormalizer.from_lerobot_dataset(mock_dataset)
|
|
||||||
|
|
||||||
assert "mean" in unnormalizer._tensor_stats
|
|
||||||
assert "std" in unnormalizer._tensor_stats
|
|
||||||
|
|
||||||
|
|
||||||
def test_missing_action_stats_error():
|
# Fixtures for NormalizerProcessor tests
|
||||||
mock_dataset = Mock()
|
|
||||||
mock_dataset.meta.stats = {
|
|
||||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
|
||||||
}
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Dataset does not contain action statistics"):
|
|
||||||
ActionUnnormalizer.from_lerobot_dataset(mock_dataset)
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_stats_error():
|
|
||||||
unnormalizer = ActionUnnormalizer(action_stats={"invalid": [1.0]})
|
|
||||||
|
|
||||||
action = torch.tensor([1.0])
|
|
||||||
transition = (None, action, None, None, None, None, None)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Action stats must contain"):
|
|
||||||
unnormalizer(transition)
|
|
||||||
|
|
||||||
|
|
||||||
# Fixtures for NormalizationProcessor tests
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def full_stats():
|
def full_stats():
|
||||||
return {
|
return {
|
||||||
@@ -361,11 +297,11 @@ def full_stats():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def normalization_processor(full_stats):
|
def normalizer_processor(full_stats):
|
||||||
return NormalizationProcessor(stats=full_stats)
|
return NormalizerProcessor(stats=full_stats)
|
||||||
|
|
||||||
|
|
||||||
def test_combined_normalization_unnormalization(normalization_processor):
|
def test_combined_normalization(normalizer_processor):
|
||||||
observation = {
|
observation = {
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
"observation.state": torch.tensor([0.5, 0.0]),
|
"observation.state": torch.tensor([0.5, 0.0]),
|
||||||
@@ -373,16 +309,16 @@ def test_combined_normalization_unnormalization(normalization_processor):
|
|||||||
action = torch.tensor([1.0, -0.5])
|
action = torch.tensor([1.0, -0.5])
|
||||||
transition = (observation, action, 1.0, False, False, {}, {})
|
transition = (observation, action, 1.0, False, False, {}, {})
|
||||||
|
|
||||||
processed_transition = normalization_processor(transition)
|
processed_transition = normalizer_processor(transition)
|
||||||
|
|
||||||
# Check normalized observations
|
# Check normalized observations
|
||||||
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||||
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||||
assert torch.allclose(processed_obs["observation.image"], expected_image)
|
assert torch.allclose(processed_obs["observation.image"], expected_image)
|
||||||
|
|
||||||
# Check unnormalized action
|
# Check normalized action
|
||||||
processed_action = processed_transition[TransitionIndex.ACTION]
|
processed_action = processed_transition[TransitionIndex.ACTION]
|
||||||
expected_action = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0])
|
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
|
||||||
assert torch.allclose(processed_action, expected_action)
|
assert torch.allclose(processed_action, expected_action)
|
||||||
|
|
||||||
# Check other fields remain unchanged
|
# Check other fields remain unchanged
|
||||||
@@ -390,45 +326,28 @@ def test_combined_normalization_unnormalization(normalization_processor):
|
|||||||
assert not processed_transition[TransitionIndex.DONE]
|
assert not processed_transition[TransitionIndex.DONE]
|
||||||
|
|
||||||
|
|
||||||
def test_disable_action_unnormalization(full_stats):
|
|
||||||
processor = NormalizationProcessor(stats=full_stats, unnormalize_action=False)
|
|
||||||
|
|
||||||
action = torch.tensor([1.0, -0.5])
|
|
||||||
transition = (None, action, None, None, None, None, None)
|
|
||||||
|
|
||||||
processed_transition = processor(transition)
|
|
||||||
|
|
||||||
# Action should remain unchanged
|
|
||||||
assert torch.allclose(processed_transition[TransitionIndex.ACTION], action)
|
|
||||||
|
|
||||||
|
|
||||||
def test_processor_from_lerobot_dataset(full_stats):
|
def test_processor_from_lerobot_dataset(full_stats):
|
||||||
# Mock dataset
|
# Mock dataset
|
||||||
mock_dataset = Mock()
|
mock_dataset = Mock()
|
||||||
mock_dataset.meta.stats = full_stats
|
mock_dataset.meta.stats = full_stats
|
||||||
|
|
||||||
processor = NormalizationProcessor.from_lerobot_dataset(
|
processor = NormalizerProcessor.from_lerobot_dataset(mock_dataset, normalize_keys={"observation.image"})
|
||||||
mock_dataset, normalize_keys={"observation.image"}, unnormalize_action=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert processor.normalize_keys == {"observation.image"}
|
assert processor.normalize_keys == {"observation.image"}
|
||||||
assert processor.unnormalize_action
|
|
||||||
assert "observation.image" in processor._tensor_stats
|
assert "observation.image" in processor._tensor_stats
|
||||||
assert "action" in processor._tensor_stats
|
assert "action" in processor._tensor_stats
|
||||||
|
|
||||||
|
|
||||||
def test_get_config(full_stats):
|
def test_get_config(full_stats):
|
||||||
processor = NormalizationProcessor(
|
processor = NormalizerProcessor(stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6)
|
||||||
stats=full_stats, normalize_keys={"observation.image"}, unnormalize_action=False, eps=1e-6
|
|
||||||
)
|
|
||||||
|
|
||||||
config = processor.get_config()
|
config = processor.get_config()
|
||||||
assert config == {"normalize_keys": ["observation.image"], "unnormalize_action": False, "eps": 1e-6}
|
assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6}
|
||||||
|
|
||||||
|
|
||||||
def test_integration_with_robot_processor(normalization_processor):
|
def test_integration_with_robot_processor(normalizer_processor):
|
||||||
"""Test integration with RobotProcessor pipeline"""
|
"""Test integration with RobotProcessor pipeline"""
|
||||||
robot_processor = RobotProcessor([normalization_processor])
|
robot_processor = RobotProcessor([normalizer_processor])
|
||||||
|
|
||||||
observation = {
|
observation = {
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
@@ -447,7 +366,7 @@ def test_integration_with_robot_processor(normalization_processor):
|
|||||||
# Edge case tests
|
# Edge case tests
|
||||||
def test_empty_observation():
|
def test_empty_observation():
|
||||||
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||||
normalizer = ObservationNormalizer(stats=stats)
|
normalizer = NormalizerProcessor(stats=stats)
|
||||||
|
|
||||||
transition = (None, None, None, None, None, None, None)
|
transition = (None, None, None, None, None, None, None)
|
||||||
result = normalizer(transition)
|
result = normalizer(transition)
|
||||||
@@ -456,7 +375,7 @@ def test_empty_observation():
|
|||||||
|
|
||||||
|
|
||||||
def test_empty_stats():
|
def test_empty_stats():
|
||||||
normalizer = ObservationNormalizer(stats={})
|
normalizer = NormalizerProcessor(stats={})
|
||||||
observation = {"observation.image": torch.tensor([0.5])}
|
observation = {"observation.image": torch.tensor([0.5])}
|
||||||
transition = (observation, None, None, None, None, None, None)
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
@@ -466,12 +385,20 @@ def test_empty_stats():
|
|||||||
|
|
||||||
|
|
||||||
def test_partial_stats():
|
def test_partial_stats():
|
||||||
stats = {
|
"""If statistics are incomplete, the value should pass through unchanged."""
|
||||||
"observation.image": {"mean": [0.5]}, # Missing std
|
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
|
||||||
}
|
normalizer = NormalizerProcessor(stats=stats)
|
||||||
normalizer = ObservationNormalizer(stats=stats)
|
|
||||||
observation = {"observation.image": torch.tensor([0.7])}
|
observation = {"observation.image": torch.tensor([0.7])}
|
||||||
transition = (observation, None, None, None, None, None, None)
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="must contain either"):
|
processed = normalizer(transition)[TransitionIndex.OBSERVATION]
|
||||||
normalizer(transition)
|
assert torch.allclose(processed["observation.image"], observation["observation.image"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_action_stats_no_error():
|
||||||
|
mock_dataset = Mock()
|
||||||
|
mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||||
|
|
||||||
|
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset)
|
||||||
|
# The tensor stats should not contain the 'action' key
|
||||||
|
assert "action" not in processor._tensor_stats
|
||||||
|
|||||||
Reference in New Issue
Block a user