diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index f11744d5e..9032ba48e 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -411,6 +411,7 @@ def main(): UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats), ] postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor") + # Determine hub repo ID if pushing to hub if args.push_to_hub: if args.hub_repo_id: @@ -424,12 +425,6 @@ def main(): else: hub_repo_id = None - # Save model using the policy's save_pretrained method - print(f"Saving model to {output_dir}...") - policy.save_pretrained( - output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private - ) - # Save preprocessor and postprocessor to root directory print(f"Saving preprocessor to {output_dir}...") preprocessor.save_pretrained(output_dir) @@ -441,6 +436,44 @@ def main(): if args.push_to_hub: postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private) + # Save model using the policy's save_pretrained method + print(f"Saving model to {output_dir}...") + policy.save_pretrained( + output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private + ) + + # Generate and save model card + print("Generating model card...") + # Get metadata from original config + dataset_repo_id = config.get("repo_id", "unknown") + license = config.get("license", "apache-2.0") + + tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type] + tags = set(tags).union({"robotics", "lerobot", policy_type}) + tags = list(tags) + + # Generate model card + card = policy.generate_model_card( + dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags + ) + + # Save model card locally + card.save(str(output_dir / "README.md")) + print(f"Model card saved to {output_dir / 'README.md'}") + # Push model card to hub if requested + if args.push_to_hub: + from huggingface_hub import HfApi + + api = HfApi() + api.upload_file( + path_or_fileobj=str(output_dir / "README.md"), + path_in_repo="README.md", + repo_id=hub_repo_id, + repo_type="model", + commit_message="Add model card for migrated model", + ) + print("Model card pushed to hub") + print("\nMigration complete!") print(f"Migrated model saved to: {output_dir}") if args.push_to_hub: