def load_model_from_hub(
    repo_id: str, revision: str = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
    """Load a model's state_dict, policy config and training config from the hub.

    Args:
        repo_id: Hub repository to download the files from.
        revision: Optional revision forwarded unchanged to ``hf_hub_download``.

    Returns:
        A ``(state_dict, config, train_config)`` tuple, where ``state_dict`` is
        loaded from ``model.safetensors``, ``config`` from ``config.json`` and
        ``train_config`` from ``train_config.json``. ``train_config`` is an
        empty dict when the repository has no ``train_config.json``, so
        callers can keep using ``train_config.get(...)`` with their own
        defaults.
    """
    # Imported locally so this function does not rely on a module-level import
    # that may not exist in the file yet.
    from huggingface_hub.utils import EntryNotFoundError

    # Download the mandatory files; a missing one is a hard error.
    safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)

    # Load state_dict
    state_dict = load_safetensors(safetensors_path)

    # Load config (JSON is UTF-8, so do not depend on the locale default encoding)
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # train_config.json is optional: not every checkpoint repository ships it,
    # and its absence must not abort the whole migration.
    try:
        train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
    except EntryNotFoundError:
        print("train_config.json not found - continuing without training configuration.")
        train_config = {}
    else:
        with open(train_config_path, encoding="utf-8") as f:
            train_config = json.load(f)

    return state_dict, config, train_config
main(): # Generate and save model card print("Generating model card...") # Get metadata from original config - dataset_repo_id = config.get("repo_id", "unknown") + dataset_repo_id = train_config.get("repo_id", "unknown") license = config.get("license", "apache-2.0") tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]