mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
feat(migrate): Extend load_model_from_hub to include train configuration
- Updated load_model_from_hub to return the train configuration alongside the model state_dict and config.
- Modified the main function to handle the additional train configuration when loading models from either the hub or a local path.
- Adjusted dataset_repo_id extraction to read from the train configuration for improved accuracy.
This commit is contained in:
committed by
Steven Palma
parent
9b1138171e
commit
b95c219d96
@@ -248,12 +248,15 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
|||||||
return converted_features
|
return converted_features
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
|
def load_model_from_hub(
|
||||||
|
repo_id: str, revision: str = None
|
||||||
|
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
||||||
"""Load model state_dict and config from hub."""
|
"""Load model state_dict and config from hub."""
|
||||||
# Download files
|
# Download files
|
||||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||||
|
|
||||||
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
||||||
|
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
|
||||||
|
|
||||||
# Load state_dict
|
# Load state_dict
|
||||||
state_dict = load_safetensors(safetensors_path)
|
state_dict = load_safetensors(safetensors_path)
|
||||||
@@ -262,7 +265,10 @@ def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, t
|
|||||||
with open(config_path) as f:
|
with open(config_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
return state_dict, config
|
with open(train_config_path) as f:
|
||||||
|
train_config = json.load(f)
|
||||||
|
|
||||||
|
return state_dict, config, train_config
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -300,9 +306,11 @@ def main():
|
|||||||
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
|
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
|
||||||
with open(os.path.join(args.pretrained_path, "config.json")) as f:
|
with open(os.path.join(args.pretrained_path, "config.json")) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
|
||||||
|
train_config = json.load(f)
|
||||||
else:
|
else:
|
||||||
# Hub repository
|
# Hub repository
|
||||||
state_dict, config = load_model_from_hub(args.pretrained_path, args.revision)
|
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
|
||||||
|
|
||||||
# Extract normalization statistics
|
# Extract normalization statistics
|
||||||
print("Extracting normalization statistics...")
|
print("Extracting normalization statistics...")
|
||||||
@@ -448,7 +456,7 @@ def main():
|
|||||||
# Generate and save model card
|
# Generate and save model card
|
||||||
print("Generating model card...")
|
print("Generating model card...")
|
||||||
# Get metadata from original config
|
# Get metadata from original config
|
||||||
dataset_repo_id = config.get("repo_id", "unknown")
|
dataset_repo_id = train_config.get("repo_id", "unknown")
|
||||||
license = config.get("license", "apache-2.0")
|
license = config.get("license", "apache-2.0")
|
||||||
|
|
||||||
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
|
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
|
||||||
|
|||||||
Reference in New Issue
Block a user