feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models

- Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
- Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
- Improved error handling for missing mandatory configuration fields during model instantiation.
This commit is contained in:
Adil Zouitine
2025-07-23 09:26:10 +02:00
committed by Steven Palma
parent 9a9c7208d2
commit b632490b4b
@@ -20,29 +20,33 @@ Generic script to migrate any policy model with normalization layers to the new
This script: This script:
1. Loads an existing pretrained policy model 1. Loads an existing pretrained policy model
2. Extracts normalization statistics from the model 2. Extracts normalization statistics from the model
3. Creates a NormalizerProcessor with these statistics 3. Creates both preprocessor and postprocessor:
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
- Postprocessor: unnormalizes outputs (actions) for inference
4. Removes normalization layers from the model state_dict 4. Removes normalization layers from the model state_dict
5. Saves the new model and processor 5. Saves the new model and both processors
Usage: Usage:
python scripts/migration/migrate_policy_normalization.py \ python src/lerobot/processor/migrate_policy_normalization.py \
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \ --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
--policy-type act \ --policy-type act \
--push-to-hub --push-to-hub
""" """
import argparse import argparse
import importlib
import json import json
import os import os
from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import torch import torch
from huggingface_hub import HfApi, hf_hub_download from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.normalize_processor import NormalizerProcessor from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from lerobot.processor.pipeline import RobotProcessor from lerobot.processor.pipeline import RobotProcessor
# Policy type to class mapping # Policy type to class mapping
@@ -63,32 +67,41 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
"""Extract normalization statistics from model state_dict.""" """Extract normalization statistics from model state_dict."""
stats = {} stats = {}
# Common patterns for normalization layers # Define patterns to match and their prefixes to remove
normalization_patterns = [ normalization_patterns = [
("normalize_inputs.buffer_", "unnormalize_outputs.buffer_"), # Common pattern "normalize_inputs.buffer_",
("normalize.", "unnormalize."), # Alternative pattern "unnormalize_outputs.buffer_",
("input_normalizer.", "output_normalizer."), # SAC pattern "normalize_targets.buffer_",
"normalize.", # Must come after normalize_* patterns
"unnormalize.", # Must come after unnormalize_* patterns
"input_normalizer.",
"output_normalizer.",
] ]
# Extract all normalization buffers # Process each key in state_dict
for key, tensor in state_dict.items(): for key, tensor in state_dict.items():
for norm_prefix, _ in normalization_patterns: # Try each pattern
if norm_prefix in key: for pattern in normalization_patterns:
# Extract the feature name and stat type if key.startswith(pattern):
parts = key.replace(norm_prefix, "").split(".") # Extract the remaining part after the pattern
remaining = key[len(pattern) :]
parts = remaining.split(".")
# Need at least feature name and stat type
if len(parts) >= 2: if len(parts) >= 2:
# Handle keys like "buffer_observation_state.mean" # Last part is the stat type (mean, std, min, max, etc.)
feature_parts = parts[:-1]
stat_type = parts[-1] stat_type = parts[-1]
# Everything else is the feature name
feature_name = ".".join(parts[:-1]).replace("_", ".")
# Reconstruct feature name (e.g., "observation.state") # Add to stats
feature_name = ".".join(feature_parts).replace("_", ".")
if feature_name not in stats: if feature_name not in stats:
stats[feature_name] = {} stats[feature_name] = {}
stats[feature_name][stat_type] = tensor.clone() stats[feature_name][stat_type] = tensor.clone()
# Only process the first matching pattern
break
return stats return stats
@@ -194,6 +207,7 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
remove_patterns = [ remove_patterns = [
"normalize_inputs.", "normalize_inputs.",
"unnormalize_outputs.", "unnormalize_outputs.",
"normalize_targets.", # Added pattern for target normalization
"normalize.", "normalize.",
"unnormalize.", "unnormalize.",
"input_normalizer.", "input_normalizer.",
@@ -209,6 +223,30 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
return new_state_dict return new_state_dict
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert features from old format to PolicyFeature objects."""
converted_features = {}
for key, feature_dict in features_dict.items():
# Determine feature type based on key
if "image" in key or "visual" in key:
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
# Get shape from feature dict
shape = feature_dict.get("shape", feature_dict.get("dim"))
shape = (shape,) if isinstance(shape, int) else tuple(shape)
converted_features[key] = PolicyFeature(feature_type, shape)
return converted_features
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
"""Load model state_dict and config from hub.""" """Load model state_dict and config from hub."""
# Download files # Download files
@@ -236,13 +274,6 @@ def main():
required=True, required=True,
help="Path to pretrained model (hub repo or local directory)", help="Path to pretrained model (hub repo or local directory)",
) )
parser.add_argument(
"--policy-type",
type=str,
required=True,
choices=list(POLICY_CLASSES.keys()),
help="Type of policy model",
)
parser.add_argument( parser.add_argument(
"--output-dir", "--output-dir",
type=str, type=str,
@@ -276,28 +307,15 @@ def main():
print("Extracting normalization statistics...") print("Extracting normalization statistics...")
stats = extract_normalization_stats(state_dict) stats = extract_normalization_stats(state_dict)
if not stats: print(f"Found normalization statistics for: {list(stats.keys())}")
print("Warning: No normalization statistics found in model. The model might already be migrated.")
else:
print(f"Found normalization statistics for: {list(stats.keys())}")
# Detect features and normalization modes # Detect input features and normalization modes
print("Detecting features and normalization modes...") print("Detecting features and normalization modes...")
features, norm_map = detect_features_and_norm_modes(config, stats) features, norm_map = detect_features_and_norm_modes(config, stats)
print(f"Detected features: {list(features.keys())}") print(f"Detected features: {list(features.keys())}")
print(f"Normalization modes: {norm_map}") print(f"Normalization modes: {norm_map}")
# Create NormalizerProcessor
print("Creating NormalizerProcessor...")
if stats:
processor = RobotProcessor(
[NormalizerProcessor(features, norm_map, stats)], name=f"{args.policy_type}_normalizer"
)
else:
# No normalization needed
processor = RobotProcessor([], name=f"{args.policy_type}_normalizer")
# Remove normalization layers from state_dict # Remove normalization layers from state_dict
print("Removing normalization layers from model...") print("Removing normalization layers from model...")
new_state_dict = remove_normalization_layers(state_dict) new_state_dict = remove_normalization_layers(state_dict)
@@ -317,20 +335,82 @@ def main():
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
# Save migrated model
print(f"Saving migrated model to {output_dir}...")
save_safetensors(new_state_dict, output_dir / "model.safetensors")
# Clean up config - remove normalization_mapping field # Clean up config - remove normalization_mapping field
cleaned_config = dict(config) cleaned_config = dict(config)
if "normalization_mapping" in cleaned_config: if "normalization_mapping" in cleaned_config:
print("Removing 'normalization_mapping' field from config") print("Removing 'normalization_mapping' field from config")
del cleaned_config["normalization_mapping"] del cleaned_config["normalization_mapping"]
policy_type = deepcopy(cleaned_config["type"])
# Save cleaned config del cleaned_config["type"]
with open(output_dir / "config.json", "w") as f:
json.dump(cleaned_config, f, indent=2)
# Instantiate the policy model with cleaned config and load the cleaned state dict
print(f"Instantiating {policy_type} policy model...")
policy_class_path = POLICY_CLASSES[policy_type]
module_path, class_name = policy_class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
policy_class = getattr(module, class_name)
# Create config class instance
config_module_path = module_path.replace("modeling", "configuration")
config_module = importlib.import_module(config_module_path)
# Handle special cases for config class names
config_class_names = {
"act": "ACTConfig",
"diffusion": "DiffusionConfig",
"pi0": "PI0Config",
"pi0fast": "PI0FASTConfig",
"smolvla": "SmolVLAConfig",
"tdmpc": "TDMPCConfig",
"vqbet": "VQBeTConfig",
"sac": "SACConfig",
"classifier": "ClassifierConfig",
}
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
config_class = getattr(config_module, config_class_name)
# Convert input_features and output_features to PolicyFeature objects - these are mandatory
if "input_features" not in cleaned_config:
raise ValueError("Missing mandatory 'input_features' in config")
if "output_features" not in cleaned_config:
raise ValueError("Missing mandatory 'output_features' in config")
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
# Create config instance from cleaned config dict
policy_config = config_class(**cleaned_config)
# Create policy instance - some policies expect dataset_stats
policy = policy_class(policy_config)
# Load the cleaned state dict
policy.load_state_dict(new_state_dict, strict=True)
print("Successfully loaded cleaned state dict into policy model")
# Now create preprocessor and postprocessor with cleaned_config available
print("Creating preprocessor and postprocessor...")
# The pattern from existing processor factories:
# - Preprocessor has two NormalizerProcessors: one for input_features, one for output_features
# - Postprocessor has one UnnormalizerProcessor for output_features only
# Get features from cleaned_config (now they're PolicyFeature objects)
input_features = cleaned_config.get("input_features", {})
output_features = cleaned_config.get("output_features", {})
# Create preprocessor with two normalizers (following the pattern from processor factories)
preprocessor_steps = [
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
]
preprocessor = RobotProcessor(preprocessor_steps, name=f"{policy_type}_preprocessor")
# Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
]
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
# Determine hub repo ID if pushing to hub # Determine hub repo ID if pushing to hub
if args.push_to_hub: if args.push_to_hub:
if args.hub_repo_id: if args.hub_repo_id:
@@ -344,50 +424,27 @@ def main():
else: else:
hub_repo_id = None hub_repo_id = None
# Save processor (and optionally push to hub) # Save model using the policy's save_pretrained method
processor_dir = output_dir / "processor" print(f"Saving model to {output_dir}...")
processor.save_pretrained( policy.save_pretrained(
processor_dir, output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
repo_id=hub_repo_id,
push_to_hub=args.push_to_hub,
private=args.private,
commit_message=f"Upload {args.policy_type} normalization processor",
) )
print(f"Saved processor to {processor_dir}")
# If pushing to hub, also upload the model files # Save preprocessor and postprocessor to root directory
print(f"Saving preprocessor to {output_dir}...")
preprocessor.save_pretrained(output_dir)
if args.push_to_hub: if args.push_to_hub:
print(f"Pushing to hub repository: {hub_repo_id}") preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
# For the model, we still need to use API since it's just safetensors + config print(f"Saving postprocessor to {output_dir}...")
api = HfApi() postprocessor.save_pretrained(output_dir)
api.create_repo(repo_id=hub_repo_id, repo_type="model", private=args.private, exist_ok=True) if args.push_to_hub:
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
# Upload model files
api.upload_file(
path_or_fileobj=str(output_dir / "model.safetensors"),
path_in_repo="model.safetensors",
repo_id=hub_repo_id,
repo_type="model",
commit_message=f"Upload {args.policy_type} model weights without normalization",
)
api.upload_file(
path_or_fileobj=str(output_dir / "config.json"),
path_in_repo="config.json",
repo_id=hub_repo_id,
repo_type="model",
commit_message=f"Upload {args.policy_type} model config",
)
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
print("\nMigration complete!") print("\nMigration complete!")
print(f"Migrated model saved to: {output_dir}") print(f"Migrated model saved to: {output_dir}")
if processor.steps: if args.push_to_hub:
print("Normalization processor created with statistics") print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
else:
print("No normalization processor needed (model had no normalization layers)")
if __name__ == "__main__": if __name__ == "__main__":