From e7be2fd113ebbb9af51530ad407e464c7f5e892a Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Wed, 23 Jul 2025 11:36:17 +0200 Subject: [PATCH] feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. --- .../processor/migrate_policy_normalization.py | 45 ++++++++++++++++--- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index f11744d5e..9032ba48e 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -411,6 +411,7 @@ def main(): UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats), ] postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor") + # Determine hub repo ID if pushing to hub if args.push_to_hub: if args.hub_repo_id: @@ -424,12 +425,6 @@ def main(): else: hub_repo_id = None - # Save model using the policy's save_pretrained method - print(f"Saving model to {output_dir}...") - policy.save_pretrained( - output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private - ) - # Save preprocessor and postprocessor to root directory print(f"Saving preprocessor to {output_dir}...") preprocessor.save_pretrained(output_dir) @@ -441,6 +436,44 @@ def main(): if args.push_to_hub: postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private) + # Save model using the policy's save_pretrained method + print(f"Saving model to {output_dir}...") + policy.save_pretrained( + output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private + ) + + # Generate and save model card + print("Generating model card...") + # Get metadata from original config + dataset_repo_id = config.get("repo_id", "unknown") + license = config.get("license", "apache-2.0") + + tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type] + tags = set(tags).union({"robotics", "lerobot", policy_type}) + tags = list(tags) + + # Generate model card + card = policy.generate_model_card( + dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags + ) + + # Save model card locally + card.save(str(output_dir / "README.md")) + print(f"Model card saved to {output_dir / 'README.md'}") + # Push model card to hub if requested + if args.push_to_hub: + from huggingface_hub import HfApi + + api = HfApi() + api.upload_file( + path_or_fileobj=str(output_dir / "README.md"), + path_in_repo="README.md", + repo_id=hub_repo_id, + repo_type="model", + commit_message="Add model card for migrated model", + ) + print("Model card pushed to hub") + print("\nMigration complete!") print(f"Migrated model saved to: {output_dir}") if args.push_to_hub: