feat(migrate): Extend load_model_from_hub to include train configuration

- Updated load_model_from_hub to return the train configuration alongside the model state_dict and config.
- Modified main function to handle the additional train configuration when loading models from both the hub and local paths.
- Adjusted dataset_repo_id extraction for the model card to read `repo_id` from the train configuration instead of the model config, for improved accuracy.
This commit is contained in:
Adil Zouitine
2025-07-24 18:13:01 +02:00
committed by Steven Palma
parent 9b1138171e
commit b95c219d96
@@ -248,12 +248,15 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
return converted_features
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
def load_model_from_hub(
repo_id: str, revision: str = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""Load model state_dict and config from hub."""
# Download files
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
# Load state_dict
state_dict = load_safetensors(safetensors_path)
@@ -262,7 +265,10 @@ def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, t
with open(config_path) as f:
config = json.load(f)
return state_dict, config
with open(train_config_path) as f:
train_config = json.load(f)
return state_dict, config, train_config
def main():
@@ -300,9 +306,11 @@ def main():
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
with open(os.path.join(args.pretrained_path, "config.json")) as f:
config = json.load(f)
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
train_config = json.load(f)
else:
# Hub repository
state_dict, config = load_model_from_hub(args.pretrained_path, args.revision)
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
# Extract normalization statistics
print("Extracting normalization statistics...")
@@ -448,7 +456,7 @@ def main():
# Generate and save model card
print("Generating model card...")
# Get metadata from original config
dataset_repo_id = config.get("repo_id", "unknown")
dataset_repo_id = train_config.get("repo_id", "unknown")
license = config.get("license", "apache-2.0")
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]