mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
feat(migrate): Add model card generation and saving to migration script
- Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub.
This commit is contained in:
committed by
Steven Palma
parent
b632490b4b
commit
e7be2fd113
@@ -411,6 +411,7 @@ def main():
|
|||||||
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||||
]
|
]
|
||||||
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
|
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
|
||||||
|
|
||||||
# Determine hub repo ID if pushing to hub
|
# Determine hub repo ID if pushing to hub
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
if args.hub_repo_id:
|
if args.hub_repo_id:
|
||||||
@@ -424,12 +425,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
hub_repo_id = None
|
hub_repo_id = None
|
||||||
|
|
||||||
# Save model using the policy's save_pretrained method
|
|
||||||
print(f"Saving model to {output_dir}...")
|
|
||||||
policy.save_pretrained(
|
|
||||||
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save preprocessor and postprocessor to root directory
|
# Save preprocessor and postprocessor to root directory
|
||||||
print(f"Saving preprocessor to {output_dir}...")
|
print(f"Saving preprocessor to {output_dir}...")
|
||||||
preprocessor.save_pretrained(output_dir)
|
preprocessor.save_pretrained(output_dir)
|
||||||
@@ -441,6 +436,44 @@ def main():
|
|||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||||
|
|
||||||
|
# Save model using the policy's save_pretrained method
|
||||||
|
print(f"Saving model to {output_dir}...")
|
||||||
|
policy.save_pretrained(
|
||||||
|
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate and save model card
|
||||||
|
print("Generating model card...")
|
||||||
|
# Get metadata from original config
|
||||||
|
dataset_repo_id = config.get("repo_id", "unknown")
|
||||||
|
license = config.get("license", "apache-2.0")
|
||||||
|
|
||||||
|
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
|
||||||
|
tags = set(tags).union({"robotics", "lerobot", policy_type})
|
||||||
|
tags = list(tags)
|
||||||
|
|
||||||
|
# Generate model card
|
||||||
|
card = policy.generate_model_card(
|
||||||
|
dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save model card locally
|
||||||
|
card.save(str(output_dir / "README.md"))
|
||||||
|
print(f"Model card saved to {output_dir / 'README.md'}")
|
||||||
|
# Push model card to hub if requested
|
||||||
|
if args.push_to_hub:
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
|
api = HfApi()
|
||||||
|
api.upload_file(
|
||||||
|
path_or_fileobj=str(output_dir / "README.md"),
|
||||||
|
path_in_repo="README.md",
|
||||||
|
repo_id=hub_repo_id,
|
||||||
|
repo_type="model",
|
||||||
|
commit_message="Add model card for migrated model",
|
||||||
|
)
|
||||||
|
print("Model card pushed to hub")
|
||||||
|
|
||||||
print("\nMigration complete!")
|
print("\nMigration complete!")
|
||||||
print(f"Migrated model saved to: {output_dir}")
|
print(f"Migrated model saved to: {output_dir}")
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
|
|||||||
Reference in New Issue
Block a user