mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Steven Palma
parent
2c2bb1e8bf
commit
427b97d198
@@ -39,8 +39,7 @@ from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
from safetensors.torch import save_file as save_safetensors
|
||||
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessor
|
||||
@@ -60,7 +59,7 @@ POLICY_CLASSES = {
|
||||
}
|
||||
|
||||
|
||||
def extract_normalization_stats(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Extract normalization statistics from model state_dict."""
|
||||
stats = {}
|
||||
|
||||
@@ -94,8 +93,8 @@ def extract_normalization_stats(state_dict: Dict[str, torch.Tensor]) -> Dict[str
|
||||
|
||||
|
||||
def detect_features_and_norm_modes(
|
||||
config: Dict[str, Any], stats: Dict[str, Dict[str, torch.Tensor]]
|
||||
) -> tuple[Dict[str, PolicyFeature], Dict[FeatureType, NormalizationMode]]:
|
||||
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
|
||||
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
|
||||
"""Detect features and normalization modes from config and stats."""
|
||||
features = {}
|
||||
norm_modes = {}
|
||||
@@ -187,7 +186,7 @@ def detect_features_and_norm_modes(
|
||||
return features, norm_modes
|
||||
|
||||
|
||||
def remove_normalization_layers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Remove normalization layers from state_dict."""
|
||||
new_state_dict = {}
|
||||
|
||||
@@ -210,7 +209,7 @@ def remove_normalization_layers(state_dict: Dict[str, torch.Tensor]) -> Dict[str
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
||||
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
|
||||
"""Load model state_dict and config from hub."""
|
||||
# Download files
|
||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||
|
||||
Reference in New Issue
Block a user