From 427b97d1988ec9400441a841fd69009a7170567d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Jul 2025 09:42:15 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../processor/migrate_policy_normalization.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index 498904143..c909bff56 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -39,8 +39,7 @@ from typing import Any, Dict import torch from huggingface_hub import HfApi, hf_hub_download -from safetensors.torch import load_file as load_safetensors -from safetensors.torch import save_file as save_safetensors +from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.processor.normalize_processor import NormalizerProcessor @@ -60,7 +59,7 @@ POLICY_CLASSES = { } -def extract_normalization_stats(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: +def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: """Extract normalization statistics from model state_dict.""" stats = {} @@ -94,8 +93,8 @@ def extract_normalization_stats(state_dict: Dict[str, torch.Tensor]) -> Dict[str def detect_features_and_norm_modes( - config: Dict[str, Any], stats: Dict[str, Dict[str, torch.Tensor]] -) -> tuple[Dict[str, PolicyFeature], Dict[FeatureType, NormalizationMode]]: + config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]] +) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]: """Detect features and normalization modes from config and stats.""" features = {} norm_modes = {} @@ -187,7 +186,7 @@ def detect_features_and_norm_modes( return features, norm_modes -def remove_normalization_layers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: +def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Remove normalization layers from state_dict.""" new_state_dict = {} @@ -210,7 +209,7 @@ def remove_normalization_layers(state_dict: Dict[str, torch.Tensor]) -> Dict[str return new_state_dict -def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[Dict[str, torch.Tensor], Dict[str, Any]]: +def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: """Load model state_dict and config from hub.""" # Download files safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)