refactor(processor): update migration script for policy normalization and hub integration

- Modified the migration script to include a branch argument for pushing to the hub, enhancing flexibility in version control.
- Improved error handling by ensuring the policy type is extracted from the configuration, promoting robustness.
- Streamlined the process of saving and pushing model components to the hub, allowing for a single commit with optional PR creation.
- Updated the commit message and description for better clarity on the migration changes and benefits, ensuring users are informed of the new architecture and usage.
This commit is contained in:
AdilZouitine
2025-09-11 21:05:20 +02:00
parent cd0098a5f7
commit f51272362c
@@ -35,8 +35,8 @@ This script performs the following steps:
Usage:
python src/lerobot/processor/migrate_policy_normalization.py \
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
--policy-type act \
--push-to-hub
--push-to-hub \
--branch main
Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config`
factory functions from `lerobot.policies.factory` to create processors and configurations,
@@ -54,7 +54,7 @@ from pathlib import Path
from typing import Any
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@@ -368,10 +368,10 @@ def main():
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
parser.add_argument(
"--policy-type",
"--branch",
type=str,
required=True,
help="Policy type (act, diffusion, pi0, pi0fast, smolvla, tdmpc, vqbet, sac, classifier)",
default=None,
help="Git branch to use when pushing to hub. If specified, a PR will be created automatically (default: push directly to main)",
)
args = parser.parse_args()
@@ -421,6 +421,13 @@ def main():
output_dir.mkdir(parents=True, exist_ok=True)
# Extract policy type from config
if "type" not in config:
raise ValueError("Policy type not found in config.json. The config must contain a 'type' field.")
policy_type = config["type"]
print(f"Detected policy type: {policy_type}")
# Clean up config - remove fields that shouldn't be passed to config constructor
cleaned_config = dict(config)
@@ -431,9 +438,6 @@ def main():
print(f"Removing '{field}' field from config")
del cleaned_config[field]
# Use the policy type from command line argument
policy_type = args.policy_type
# Convert input_features and output_features to PolicyFeature objects if they exist
if "input_features" in cleaned_config:
cleaned_config["input_features"] = convert_features_to_policy_features(
@@ -476,23 +480,15 @@ def main():
else:
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
# Save preprocessor and postprocessor to root directory
# Save all components to local directory first
print(f"Saving preprocessor to {output_dir}...")
preprocessor.save_pretrained(output_dir)
if args.push_to_hub and hub_repo_id:
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
print(f"Saving postprocessor to {output_dir}...")
postprocessor.save_pretrained(output_dir)
if args.push_to_hub and hub_repo_id:
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
# Save model using the policy's save_pretrained method
print(f"Saving model to {output_dir}...")
if args.push_to_hub and hub_repo_id:
policy.save_pretrained(output_dir, push_to_hub=True, repo_id=hub_repo_id, private=args.private)
else:
policy.save_pretrained(output_dir)
policy.save_pretrained(output_dir)
# Generate and save model card
print("Generating model card...")
@@ -512,24 +508,111 @@ def main():
# Save model card locally
card.save(str(output_dir / "README.md"))
print(f"Model card saved to {output_dir / 'README.md'}")
# Push model card to hub if requested
# Push all files to hub in a single operation if requested
if args.push_to_hub and hub_repo_id:
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
path_or_fileobj=str(output_dir / "README.md"),
path_in_repo="README.md",
repo_id=hub_repo_id,
repo_type="model",
commit_message="Add model card for migrated model",
)
print("Model card pushed to hub")
# Determine if we should create a PR (automatically if branch is specified)
create_pr = args.branch is not None
target_location = f"branch '{args.branch}'" if args.branch else "main branch"
print(f"Pushing all migrated files to {hub_repo_id} on {target_location}...")
# Upload all files in a single commit with automatic PR creation if branch specified
commit_message = "Migrate policy to PolicyProcessorPipeline system"
commit_description = None
if create_pr:
# Separate commit description for PR body
commit_description = """🤖 **Automated Policy Migration to PolicyProcessorPipeline**
This PR migrates your model to the new LeRobot policy format using the modern PolicyProcessorPipeline architecture.
## What Changed
### ✨ **New Architecture - PolicyProcessorPipeline**
Your model now uses external PolicyProcessorPipeline components for data processing instead of built-in normalization layers. This provides:
- **Modularity**: Separate preprocessing and postprocessing pipelines
- **Flexibility**: Easy to swap, configure, and debug processing steps
- **Compatibility**: Works with the latest LeRobot ecosystem
### 🔧 **Normalization Extraction**
We've extracted normalization statistics from your model's state_dict and removed the built-in normalization layers:
- **Extracted patterns**: `normalize_inputs.*`, `unnormalize_outputs.*`, `normalize.*`, `unnormalize.*`, `input_normalizer.*`, `output_normalizer.*`
- **Statistics preserved**: Mean, std, min, max values for all features
- **Clean model**: State dict now contains only core model weights
### 📦 **Files Added**
- **preprocessor_config.json**: Configuration for input preprocessing pipeline
- **postprocessor_config.json**: Configuration for output postprocessing pipeline
- **model.safetensors**: Clean model weights without normalization layers
- **config.json**: Updated model configuration
- **train_config.json**: Training configuration
- **README.md**: Updated model card with migration information
### 🚀 **Benefits**
- **Backward Compatible**: Your model behavior remains identical
- **Future Ready**: Compatible with latest LeRobot features and updates
- **Debuggable**: Easy to inspect and modify processing steps
- **Portable**: Processors can be shared and reused across models
### 💻 **Usage**
```python
# Load your migrated model
from lerobot.policies import get_policy_class
from lerobot.processor import PolicyProcessorPipeline
# The preprocessor and postprocessor are now external
preprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="preprocessor_config.json")
postprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="postprocessor_config.json")
policy = get_policy_class("your-policy-type").from_pretrained("your-model-repo")
# Process data through the pipeline
processed_batch = preprocessor(raw_batch)
action = policy(processed_batch)
final_action = postprocessor(action)
```
*Generated automatically by the LeRobot policy migration script*"""
upload_kwargs = {
"repo_id": hub_repo_id,
"folder_path": output_dir,
"repo_type": "model",
"commit_message": commit_message,
"revision": args.branch,
"create_pr": create_pr,
"allow_patterns": ["*.json", "*.safetensors", "*.md"],
"ignore_patterns": ["*.tmp", "*.log"],
}
# Add commit_description for PR body if creating PR
if create_pr and commit_description:
upload_kwargs["commit_description"] = commit_description
api.upload_folder(**upload_kwargs)
if create_pr:
print("All files pushed and pull request created successfully!")
else:
print("All files pushed to main branch successfully!")
print("\nMigration complete!")
print(f"Migrated model saved to: {output_dir}")
if args.push_to_hub and hub_repo_id:
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
if args.branch:
print(
f"Successfully pushed all files to branch '{args.branch}' and created PR on https://huggingface.co/{hub_repo_id}"
)
else:
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
if args.branch:
print(f"\nView the branch at: https://huggingface.co/{hub_repo_id}/tree/{args.branch}")
print(
f"View the PR at: https://huggingface.co/{hub_repo_id}/discussions (look for the most recent PR)"
)
else:
print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}")
if __name__ == "__main__":