mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
b08149a113
commit
116059a43e
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Mapping, Optional, Set
|
from typing import Any, Mapping
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -51,12 +51,12 @@ class NormalizerProcessor:
|
|||||||
norm_map: dict[FeatureType, NormalizationMode]
|
norm_map: dict[FeatureType, NormalizationMode]
|
||||||
|
|
||||||
# Pre-computed statistics coming from dataset.meta.stats for instance.
|
# Pre-computed statistics coming from dataset.meta.stats for instance.
|
||||||
stats: Optional[dict[str, dict[str, Any]]] = None
|
stats: dict[str, dict[str, Any]] | None = None
|
||||||
|
|
||||||
# Explicit subset of keys to normalise. If ``None`` every key (except
|
# Explicit subset of keys to normalise. If ``None`` every key (except
|
||||||
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
|
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
|
||||||
# membership checks O(1).
|
# membership checks O(1).
|
||||||
normalize_keys: Optional[Set[str]] = None
|
normalize_keys: set[str] | None = None
|
||||||
|
|
||||||
eps: float = 1e-8
|
eps: float = 1e-8
|
||||||
|
|
||||||
@@ -69,9 +69,9 @@ class NormalizerProcessor:
|
|||||||
features: dict[str, PolicyFeature],
|
features: dict[str, PolicyFeature],
|
||||||
norm_map: dict[FeatureType, NormalizationMode],
|
norm_map: dict[FeatureType, NormalizationMode],
|
||||||
*,
|
*,
|
||||||
normalize_keys: Optional[Set[str]] = None,
|
normalize_keys: set[str] | None = None,
|
||||||
eps: float = 1e-8,
|
eps: float = 1e-8,
|
||||||
) -> "NormalizerProcessor":
|
) -> NormalizerProcessor:
|
||||||
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
|
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
|
||||||
|
|
||||||
The features and norm_map parameters are mandatory to match the design
|
The features and norm_map parameters are mandatory to match the design
|
||||||
@@ -195,7 +195,7 @@ class UnnormalizerProcessor:
|
|||||||
|
|
||||||
features: dict[str, PolicyFeature]
|
features: dict[str, PolicyFeature]
|
||||||
norm_map: dict[FeatureType, NormalizationMode]
|
norm_map: dict[FeatureType, NormalizationMode]
|
||||||
stats: Optional[dict[str, dict[str, Any]]] = None
|
stats: dict[str, dict[str, Any]] | None = None
|
||||||
eps: float = 1e-8
|
eps: float = 1e-8
|
||||||
|
|
||||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||||
@@ -208,7 +208,7 @@ class UnnormalizerProcessor:
|
|||||||
norm_map: dict[FeatureType, NormalizationMode],
|
norm_map: dict[FeatureType, NormalizationMode],
|
||||||
*,
|
*,
|
||||||
eps: float = 1e-8,
|
eps: float = 1e-8,
|
||||||
) -> "UnnormalizerProcessor":
|
) -> UnnormalizerProcessor:
|
||||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
|
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ def test_complex_nested_observation():
|
|||||||
reconstructed_batch = _default_transition_to_batch(transition)
|
reconstructed_batch = _default_transition_to_batch(transition)
|
||||||
|
|
||||||
# Check that all observation keys are preserved
|
# Check that all observation keys are preserved
|
||||||
original_obs_keys = {k for k in batch.keys() if k.startswith("observation.")}
|
original_obs_keys = {k for k in batch if k.startswith("observation.")}
|
||||||
reconstructed_obs_keys = {k for k in reconstructed_batch.keys() if k.startswith("observation.")}
|
reconstructed_obs_keys = {k for k in reconstructed_batch.keys() if k.startswith("observation.")}
|
||||||
|
|
||||||
assert original_obs_keys == reconstructed_obs_keys
|
assert original_obs_keys == reconstructed_obs_keys
|
||||||
|
|||||||
Reference in New Issue
Block a user