mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
refactor(processor): update migration script for policy normalization and hub integration
- Modified the migration script to include a branch argument for pushing to the hub, enhancing flexibility in version control. - Improved error handling by ensuring the policy type is extracted from the configuration, promoting robustness. - Streamlined the process of saving and pushing model components to the hub, allowing for a single commit with optional PR creation. - Updated the commit message and description for better clarity on the migration changes and benefits, ensuring users are informed of the new architecture and usage.
This commit is contained in:
@@ -35,8 +35,8 @@ This script performs the following steps:
|
||||
Usage:
|
||||
python src/lerobot/processor/migrate_policy_normalization.py \
|
||||
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--policy-type act \
|
||||
--push-to-hub
|
||||
--push-to-hub \
|
||||
--branch main
|
||||
|
||||
Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config`
|
||||
factory functions from `lerobot.policies.factory` to create processors and configurations,
|
||||
@@ -54,7 +54,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
@@ -368,10 +368,10 @@ def main():
|
||||
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
|
||||
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
|
||||
parser.add_argument(
|
||||
"--policy-type",
|
||||
"--branch",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Policy type (act, diffusion, pi0, pi0fast, smolvla, tdmpc, vqbet, sac, classifier)",
|
||||
default=None,
|
||||
help="Git branch to use when pushing to hub. If specified, a PR will be created automatically (default: push directly to main)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -421,6 +421,13 @@ def main():
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Extract policy type from config
|
||||
if "type" not in config:
|
||||
raise ValueError("Policy type not found in config.json. The config must contain a 'type' field.")
|
||||
|
||||
policy_type = config["type"]
|
||||
print(f"Detected policy type: {policy_type}")
|
||||
|
||||
# Clean up config - remove fields that shouldn't be passed to config constructor
|
||||
cleaned_config = dict(config)
|
||||
|
||||
@@ -431,9 +438,6 @@ def main():
|
||||
print(f"Removing '{field}' field from config")
|
||||
del cleaned_config[field]
|
||||
|
||||
# Use the policy type from command line argument
|
||||
policy_type = args.policy_type
|
||||
|
||||
# Convert input_features and output_features to PolicyFeature objects if they exist
|
||||
if "input_features" in cleaned_config:
|
||||
cleaned_config["input_features"] = convert_features_to_policy_features(
|
||||
@@ -476,23 +480,15 @@ def main():
|
||||
else:
|
||||
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
|
||||
|
||||
# Save preprocessor and postprocessor to root directory
|
||||
# Save all components to local directory first
|
||||
print(f"Saving preprocessor to {output_dir}...")
|
||||
preprocessor.save_pretrained(output_dir)
|
||||
if args.push_to_hub and hub_repo_id:
|
||||
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||
|
||||
print(f"Saving postprocessor to {output_dir}...")
|
||||
postprocessor.save_pretrained(output_dir)
|
||||
if args.push_to_hub and hub_repo_id:
|
||||
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||
|
||||
# Save model using the policy's save_pretrained method
|
||||
print(f"Saving model to {output_dir}...")
|
||||
if args.push_to_hub and hub_repo_id:
|
||||
policy.save_pretrained(output_dir, push_to_hub=True, repo_id=hub_repo_id, private=args.private)
|
||||
else:
|
||||
policy.save_pretrained(output_dir)
|
||||
policy.save_pretrained(output_dir)
|
||||
|
||||
# Generate and save model card
|
||||
print("Generating model card...")
|
||||
@@ -512,24 +508,111 @@ def main():
|
||||
# Save model card locally
|
||||
card.save(str(output_dir / "README.md"))
|
||||
print(f"Model card saved to {output_dir / 'README.md'}")
|
||||
# Push model card to hub if requested
|
||||
# Push all files to hub in a single operation if requested
|
||||
if args.push_to_hub and hub_repo_id:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_dir / "README.md"),
|
||||
path_in_repo="README.md",
|
||||
repo_id=hub_repo_id,
|
||||
repo_type="model",
|
||||
commit_message="Add model card for migrated model",
|
||||
)
|
||||
print("Model card pushed to hub")
|
||||
|
||||
# Determine if we should create a PR (automatically if branch is specified)
|
||||
create_pr = args.branch is not None
|
||||
target_location = f"branch '{args.branch}'" if args.branch else "main branch"
|
||||
|
||||
print(f"Pushing all migrated files to {hub_repo_id} on {target_location}...")
|
||||
|
||||
# Upload all files in a single commit with automatic PR creation if branch specified
|
||||
commit_message = "Migrate policy to PolicyProcessorPipeline system"
|
||||
commit_description = None
|
||||
|
||||
if create_pr:
|
||||
# Separate commit description for PR body
|
||||
commit_description = """🤖 **Automated Policy Migration to PolicyProcessorPipeline**
|
||||
|
||||
This PR migrates your model to the new LeRobot policy format using the modern PolicyProcessorPipeline architecture.
|
||||
|
||||
## What Changed
|
||||
|
||||
### ✨ **New Architecture - PolicyProcessorPipeline**
|
||||
Your model now uses external PolicyProcessorPipeline components for data processing instead of built-in normalization layers. This provides:
|
||||
- **Modularity**: Separate preprocessing and postprocessing pipelines
|
||||
- **Flexibility**: Easy to swap, configure, and debug processing steps
|
||||
- **Compatibility**: Works with the latest LeRobot ecosystem
|
||||
|
||||
### 🔧 **Normalization Extraction**
|
||||
We've extracted normalization statistics from your model's state_dict and removed the built-in normalization layers:
|
||||
- **Extracted patterns**: `normalize_inputs.*`, `unnormalize_outputs.*`, `normalize.*`, `unnormalize.*`, `input_normalizer.*`, `output_normalizer.*`
|
||||
- **Statistics preserved**: Mean, std, min, max values for all features
|
||||
- **Clean model**: State dict now contains only core model weights
|
||||
|
||||
### 📦 **Files Added**
|
||||
- **preprocessor_config.json**: Configuration for input preprocessing pipeline
|
||||
- **postprocessor_config.json**: Configuration for output postprocessing pipeline
|
||||
- **model.safetensors**: Clean model weights without normalization layers
|
||||
- **config.json**: Updated model configuration
|
||||
- **train_config.json**: Training configuration
|
||||
- **README.md**: Updated model card with migration information
|
||||
|
||||
### 🚀 **Benefits**
|
||||
- **Backward Compatible**: Your model behavior remains identical
|
||||
- **Future Ready**: Compatible with latest LeRobot features and updates
|
||||
- **Debuggable**: Easy to inspect and modify processing steps
|
||||
- **Portable**: Processors can be shared and reused across models
|
||||
|
||||
### 💻 **Usage**
|
||||
```python
|
||||
# Load your migrated model
|
||||
from lerobot.policies import get_policy_class
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
|
||||
# The preprocessor and postprocessor are now external
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="preprocessor_config.json")
|
||||
postprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="postprocessor_config.json")
|
||||
policy = get_policy_class("your-policy-type").from_pretrained("your-model-repo")
|
||||
|
||||
# Process data through the pipeline
|
||||
processed_batch = preprocessor(raw_batch)
|
||||
action = policy(processed_batch)
|
||||
final_action = postprocessor(action)
|
||||
```
|
||||
|
||||
*Generated automatically by the LeRobot policy migration script*"""
|
||||
|
||||
upload_kwargs = {
|
||||
"repo_id": hub_repo_id,
|
||||
"folder_path": output_dir,
|
||||
"repo_type": "model",
|
||||
"commit_message": commit_message,
|
||||
"revision": args.branch,
|
||||
"create_pr": create_pr,
|
||||
"allow_patterns": ["*.json", "*.safetensors", "*.md"],
|
||||
"ignore_patterns": ["*.tmp", "*.log"],
|
||||
}
|
||||
|
||||
# Add commit_description for PR body if creating PR
|
||||
if create_pr and commit_description:
|
||||
upload_kwargs["commit_description"] = commit_description
|
||||
|
||||
api.upload_folder(**upload_kwargs)
|
||||
|
||||
if create_pr:
|
||||
print("All files pushed and pull request created successfully!")
|
||||
else:
|
||||
print("All files pushed to main branch successfully!")
|
||||
|
||||
print("\nMigration complete!")
|
||||
print(f"Migrated model saved to: {output_dir}")
|
||||
if args.push_to_hub and hub_repo_id:
|
||||
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
|
||||
if args.branch:
|
||||
print(
|
||||
f"Successfully pushed all files to branch '{args.branch}' and created PR on https://huggingface.co/{hub_repo_id}"
|
||||
)
|
||||
else:
|
||||
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
|
||||
if args.branch:
|
||||
print(f"\nView the branch at: https://huggingface.co/{hub_repo_id}/tree/{args.branch}")
|
||||
print(
|
||||
f"View the PR at: https://huggingface.co/{hub_repo_id}/discussions (look for the most recent PR)"
|
||||
)
|
||||
else:
|
||||
print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user