mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Steven Palma
parent
2c2bb1e8bf
commit
427b97d198
@@ -39,8 +39,7 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
from safetensors.torch import load_file as load_safetensors
|
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
|
||||||
from safetensors.torch import save_file as save_safetensors
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from lerobot.processor.normalize_processor import NormalizerProcessor
|
from lerobot.processor.normalize_processor import NormalizerProcessor
|
||||||
@@ -60,7 +59,7 @@ POLICY_CLASSES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def extract_normalization_stats(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||||
"""Extract normalization statistics from model state_dict."""
|
"""Extract normalization statistics from model state_dict."""
|
||||||
stats = {}
|
stats = {}
|
||||||
|
|
||||||
@@ -94,8 +93,8 @@ def extract_normalization_stats(state_dict: Dict[str, torch.Tensor]) -> Dict[str
|
|||||||
|
|
||||||
|
|
||||||
def detect_features_and_norm_modes(
|
def detect_features_and_norm_modes(
|
||||||
config: Dict[str, Any], stats: Dict[str, Dict[str, torch.Tensor]]
|
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
|
||||||
) -> tuple[Dict[str, PolicyFeature], Dict[FeatureType, NormalizationMode]]:
|
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
|
||||||
"""Detect features and normalization modes from config and stats."""
|
"""Detect features and normalization modes from config and stats."""
|
||||||
features = {}
|
features = {}
|
||||||
norm_modes = {}
|
norm_modes = {}
|
||||||
@@ -187,7 +186,7 @@ def detect_features_and_norm_modes(
|
|||||||
return features, norm_modes
|
return features, norm_modes
|
||||||
|
|
||||||
|
|
||||||
def remove_normalization_layers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||||
"""Remove normalization layers from state_dict."""
|
"""Remove normalization layers from state_dict."""
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
|
|
||||||
@@ -210,7 +209,7 @@ def remove_normalization_layers(state_dict: Dict[str, torch.Tensor]) -> Dict[str
|
|||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
|
||||||
"""Load model state_dict and config from hub."""
|
"""Load model state_dict and config from hub."""
|
||||||
# Download files
|
# Download files
|
||||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||||
|
|||||||
Reference in New Issue
Block a user