[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-07-06 19:31:23 +00:00
committed by Adil Zouitine
parent b08149a113
commit 116059a43e
2 changed files with 8 additions and 8 deletions
+7 -7
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Mapping, Optional, Set from typing import Any, Mapping
import numpy as np import numpy as np
import torch import torch
@@ -51,12 +51,12 @@ class NormalizerProcessor:
norm_map: dict[FeatureType, NormalizationMode] norm_map: dict[FeatureType, NormalizationMode]
# Pre-computed statistics coming from dataset.meta.stats for instance. # Pre-computed statistics coming from dataset.meta.stats for instance.
stats: Optional[dict[str, dict[str, Any]]] = None stats: dict[str, dict[str, Any]] | None = None
# Explicit subset of keys to normalise. If ``None`` every key (except # Explicit subset of keys to normalise. If ``None`` every key (except
# "action") found in ``stats`` will be normalised. Using a ``set`` makes # "action") found in ``stats`` will be normalised. Using a ``set`` makes
# membership checks O(1). # membership checks O(1).
normalize_keys: Optional[Set[str]] = None normalize_keys: set[str] | None = None
eps: float = 1e-8 eps: float = 1e-8
@@ -69,9 +69,9 @@ class NormalizerProcessor:
features: dict[str, PolicyFeature], features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode], norm_map: dict[FeatureType, NormalizationMode],
*, *,
normalize_keys: Optional[Set[str]] = None, normalize_keys: set[str] | None = None,
eps: float = 1e-8, eps: float = 1e-8,
) -> "NormalizerProcessor": ) -> NormalizerProcessor:
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`. """Factory helper that pulls statistics from a :class:`LeRobotDataset`.
The features and norm_map parameters are mandatory to match the design The features and norm_map parameters are mandatory to match the design
@@ -195,7 +195,7 @@ class UnnormalizerProcessor:
features: dict[str, PolicyFeature] features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode] norm_map: dict[FeatureType, NormalizationMode]
stats: Optional[dict[str, dict[str, Any]]] = None stats: dict[str, dict[str, Any]] | None = None
eps: float = 1e-8 eps: float = 1e-8
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@@ -208,7 +208,7 @@ class UnnormalizerProcessor:
norm_map: dict[FeatureType, NormalizationMode], norm_map: dict[FeatureType, NormalizationMode],
*, *,
eps: float = 1e-8, eps: float = 1e-8,
) -> "UnnormalizerProcessor": ) -> UnnormalizerProcessor:
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps) return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
def __post_init__(self): def __post_init__(self):
+1 -1
View File
@@ -223,7 +223,7 @@ def test_complex_nested_observation():
reconstructed_batch = _default_transition_to_batch(transition) reconstructed_batch = _default_transition_to_batch(transition)
# Check that all observation keys are preserved # Check that all observation keys are preserved
original_obs_keys = {k for k in batch.keys() if k.startswith("observation.")} original_obs_keys = {k for k in batch if k.startswith("observation.")}
reconstructed_obs_keys = {k for k in reconstructed_batch.keys() if k.startswith("observation.")} reconstructed_obs_keys = {k for k in reconstructed_batch.keys() if k.startswith("observation.")}
assert original_obs_keys == reconstructed_obs_keys assert original_obs_keys == reconstructed_obs_keys