[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-07-22 09:42:15 +00:00
committed by Steven Palma
parent 2c2bb1e8bf
commit 427b97d198
@@ -39,8 +39,7 @@ from typing import Any, Dict
import torch
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file as load_safetensors
from safetensors.torch import save_file as save_safetensors
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.normalize_processor import NormalizerProcessor
@@ -60,7 +59,7 @@ POLICY_CLASSES = {
}
def extract_normalization_stats(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Extract normalization statistics from model state_dict."""
stats = {}
@@ -94,8 +93,8 @@ def extract_normalization_stats(state_dict: Dict[str, torch.Tensor]) -> Dict[str
def detect_features_and_norm_modes(
config: Dict[str, Any], stats: Dict[str, Dict[str, torch.Tensor]]
) -> tuple[Dict[str, PolicyFeature], Dict[FeatureType, NormalizationMode]]:
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
"""Detect features and normalization modes from config and stats."""
features = {}
norm_modes = {}
@@ -187,7 +186,7 @@ def detect_features_and_norm_modes(
return features, norm_modes
def remove_normalization_layers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Remove normalization layers from state_dict."""
new_state_dict = {}
@@ -210,7 +209,7 @@ def remove_normalization_layers(state_dict: Dict[str, torch.Tensor]) -> Dict[str
return new_state_dict
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
"""Load model state_dict and config from hub."""
# Download files
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)