mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
b08149a113
commit
116059a43e
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Mapping, Optional, Set
|
||||
from typing import Any, Mapping
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -51,12 +51,12 @@ class NormalizerProcessor:
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
|
||||
# Pre-computed statistics coming from dataset.meta.stats for instance.
|
||||
stats: Optional[dict[str, dict[str, Any]]] = None
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
# Explicit subset of keys to normalise. If ``None`` every key (except
|
||||
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
|
||||
# membership checks O(1).
|
||||
normalize_keys: Optional[Set[str]] = None
|
||||
normalize_keys: set[str] | None = None
|
||||
|
||||
eps: float = 1e-8
|
||||
|
||||
@@ -69,9 +69,9 @@ class NormalizerProcessor:
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
normalize_keys: Optional[Set[str]] = None,
|
||||
normalize_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
) -> "NormalizerProcessor":
|
||||
) -> NormalizerProcessor:
|
||||
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
|
||||
|
||||
The features and norm_map parameters are mandatory to match the design
|
||||
@@ -195,7 +195,7 @@ class UnnormalizerProcessor:
|
||||
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
stats: Optional[dict[str, dict[str, Any]]] = None
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
eps: float = 1e-8
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
@@ -208,7 +208,7 @@ class UnnormalizerProcessor:
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
eps: float = 1e-8,
|
||||
) -> "UnnormalizerProcessor":
|
||||
) -> UnnormalizerProcessor:
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@@ -223,7 +223,7 @@ def test_complex_nested_observation():
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
|
||||
# Check that all observation keys are preserved
|
||||
original_obs_keys = {k for k in batch.keys() if k.startswith("observation.")}
|
||||
original_obs_keys = {k for k in batch if k.startswith("observation.")}
|
||||
reconstructed_obs_keys = {k for k in reconstructed_batch.keys() if k.startswith("observation.")}
|
||||
|
||||
assert original_obs_keys == reconstructed_obs_keys
|
||||
|
||||
Reference in New Issue
Block a user