[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-07-06 19:31:23 +00:00
committed by Adil Zouitine
parent b08149a113
commit 116059a43e
2 changed files with 8 additions and 8 deletions
+7 -7
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Mapping, Optional, Set
from typing import Any, Mapping
import numpy as np
import torch
@@ -51,12 +51,12 @@ class NormalizerProcessor:
norm_map: dict[FeatureType, NormalizationMode]
# Pre-computed statistics coming from dataset.meta.stats for instance.
stats: Optional[dict[str, dict[str, Any]]] = None
stats: dict[str, dict[str, Any]] | None = None
# Explicit subset of keys to normalise. If ``None`` every key (except
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
# membership checks O(1).
normalize_keys: Optional[Set[str]] = None
normalize_keys: set[str] | None = None
eps: float = 1e-8
@@ -69,9 +69,9 @@ class NormalizerProcessor:
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
normalize_keys: Optional[Set[str]] = None,
normalize_keys: set[str] | None = None,
eps: float = 1e-8,
) -> "NormalizerProcessor":
) -> NormalizerProcessor:
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
The features and norm_map parameters are mandatory to match the design
@@ -195,7 +195,7 @@ class UnnormalizerProcessor:
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
stats: Optional[dict[str, dict[str, Any]]] = None
stats: dict[str, dict[str, Any]] | None = None
eps: float = 1e-8
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@@ -208,7 +208,7 @@ class UnnormalizerProcessor:
norm_map: dict[FeatureType, NormalizationMode],
*,
eps: float = 1e-8,
) -> "UnnormalizerProcessor":
) -> UnnormalizerProcessor:
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
def __post_init__(self):
+1 -1
View File
@@ -223,7 +223,7 @@ def test_complex_nested_observation():
reconstructed_batch = _default_transition_to_batch(transition)
# Check that all observation keys are preserved
original_obs_keys = {k for k in batch.keys() if k.startswith("observation.")}
original_obs_keys = {k for k in batch if k.startswith("observation.")}
reconstructed_obs_keys = {k for k in reconstructed_batch.keys() if k.startswith("observation.")}
assert original_obs_keys == reconstructed_obs_keys