mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models
- Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
- Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
- Improved error handling for missing mandatory configuration fields during model instantiation.
This commit is contained in:
committed by
Steven Palma
parent
9a9c7208d2
commit
b632490b4b
@@ -20,29 +20,33 @@ Generic script to migrate any policy model with normalization layers to the new
|
|||||||
This script:
|
This script:
|
||||||
1. Loads an existing pretrained policy model
|
1. Loads an existing pretrained policy model
|
||||||
2. Extracts normalization statistics from the model
|
2. Extracts normalization statistics from the model
|
||||||
3. Creates a NormalizerProcessor with these statistics
|
3. Creates both preprocessor and postprocessor:
|
||||||
|
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
|
||||||
|
- Postprocessor: unnormalizes outputs (actions) for inference
|
||||||
4. Removes normalization layers from the model state_dict
|
4. Removes normalization layers from the model state_dict
|
||||||
5. Saves the new model and processor
|
5. Saves the new model and both processors
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python scripts/migration/migrate_policy_normalization.py \
|
python src/lerobot/processor/migrate_policy_normalization.py \
|
||||||
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
|
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
|
||||||
--policy-type act \
|
--policy-type act \
|
||||||
--push-to-hub
|
--push-to-hub
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import importlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
|
from safetensors.torch import load_file as load_safetensors
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from lerobot.processor.normalize_processor import NormalizerProcessor
|
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||||
from lerobot.processor.pipeline import RobotProcessor
|
from lerobot.processor.pipeline import RobotProcessor
|
||||||
|
|
||||||
# Policy type to class mapping
|
# Policy type to class mapping
|
||||||
@@ -63,32 +67,41 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
|
|||||||
"""Extract normalization statistics from model state_dict."""
|
"""Extract normalization statistics from model state_dict."""
|
||||||
stats = {}
|
stats = {}
|
||||||
|
|
||||||
# Common patterns for normalization layers
|
# Define patterns to match and their prefixes to remove
|
||||||
normalization_patterns = [
|
normalization_patterns = [
|
||||||
("normalize_inputs.buffer_", "unnormalize_outputs.buffer_"), # Common pattern
|
"normalize_inputs.buffer_",
|
||||||
("normalize.", "unnormalize."), # Alternative pattern
|
"unnormalize_outputs.buffer_",
|
||||||
("input_normalizer.", "output_normalizer."), # SAC pattern
|
"normalize_targets.buffer_",
|
||||||
|
"normalize.", # Must come after normalize_* patterns
|
||||||
|
"unnormalize.", # Must come after unnormalize_* patterns
|
||||||
|
"input_normalizer.",
|
||||||
|
"output_normalizer.",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Extract all normalization buffers
|
# Process each key in state_dict
|
||||||
for key, tensor in state_dict.items():
|
for key, tensor in state_dict.items():
|
||||||
for norm_prefix, _ in normalization_patterns:
|
# Try each pattern
|
||||||
if norm_prefix in key:
|
for pattern in normalization_patterns:
|
||||||
# Extract the feature name and stat type
|
if key.startswith(pattern):
|
||||||
parts = key.replace(norm_prefix, "").split(".")
|
# Extract the remaining part after the pattern
|
||||||
|
remaining = key[len(pattern) :]
|
||||||
|
parts = remaining.split(".")
|
||||||
|
|
||||||
|
# Need at least feature name and stat type
|
||||||
if len(parts) >= 2:
|
if len(parts) >= 2:
|
||||||
# Handle keys like "buffer_observation_state.mean"
|
# Last part is the stat type (mean, std, min, max, etc.)
|
||||||
feature_parts = parts[:-1]
|
|
||||||
stat_type = parts[-1]
|
stat_type = parts[-1]
|
||||||
|
# Everything else is the feature name
|
||||||
|
feature_name = ".".join(parts[:-1]).replace("_", ".")
|
||||||
|
|
||||||
# Reconstruct feature name (e.g., "observation.state")
|
# Add to stats
|
||||||
feature_name = ".".join(feature_parts).replace("_", ".")
|
|
||||||
|
|
||||||
if feature_name not in stats:
|
if feature_name not in stats:
|
||||||
stats[feature_name] = {}
|
stats[feature_name] = {}
|
||||||
|
|
||||||
stats[feature_name][stat_type] = tensor.clone()
|
stats[feature_name][stat_type] = tensor.clone()
|
||||||
|
|
||||||
|
# Only process the first matching pattern
|
||||||
|
break
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
@@ -194,6 +207,7 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
|
|||||||
remove_patterns = [
|
remove_patterns = [
|
||||||
"normalize_inputs.",
|
"normalize_inputs.",
|
||||||
"unnormalize_outputs.",
|
"unnormalize_outputs.",
|
||||||
|
"normalize_targets.", # Added pattern for target normalization
|
||||||
"normalize.",
|
"normalize.",
|
||||||
"unnormalize.",
|
"unnormalize.",
|
||||||
"input_normalizer.",
|
"input_normalizer.",
|
||||||
@@ -209,6 +223,30 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
|
|||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
    """Convert features from the old dict format to PolicyFeature objects.

    The feature type is inferred from substrings of each key, checked in this
    order: "image" or "visual" -> FeatureType.VISUAL, "state" ->
    FeatureType.STATE, "action" -> FeatureType.ACTION. Keys matching none of
    these fall back to FeatureType.STATE.

    Args:
        features_dict: Mapping from feature name to a dict that carries the
            feature shape under "shape" (or under "dim" as a fallback). The
            value may be a single int or an iterable of dimensions.

    Returns:
        A dict with the same keys, each mapped to
        ``PolicyFeature(feature_type, shape)`` where ``shape`` is a tuple.

    Raises:
        ValueError: If a feature provides neither "shape" nor "dim".
    """
    converted_features = {}

    for key, feature_dict in features_dict.items():
        # Resolve the shape first so a malformed entry is reported with its
        # key, instead of surfacing later as an opaque TypeError from tuple(None).
        shape = feature_dict.get("shape")
        if shape is None:
            shape = feature_dict.get("dim")
        if shape is None:
            raise ValueError(f"Feature '{key}' has neither 'shape' nor 'dim'; cannot build a PolicyFeature")
        shape = (shape,) if isinstance(shape, int) else tuple(shape)

        # Determine feature type based on key. Order matters: a key containing
        # both "image" and "state" is treated as visual.
        if "image" in key or "visual" in key:
            feature_type = FeatureType.VISUAL
        elif "action" in key and "state" not in key:
            feature_type = FeatureType.ACTION
        else:
            # Covers keys containing "state" as well as unrecognized keys.
            feature_type = FeatureType.STATE

        converted_features[key] = PolicyFeature(feature_type, shape)

    return converted_features
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
|
def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
|
||||||
"""Load model state_dict and config from hub."""
|
"""Load model state_dict and config from hub."""
|
||||||
# Download files
|
# Download files
|
||||||
@@ -236,13 +274,6 @@ def main():
|
|||||||
required=True,
|
required=True,
|
||||||
help="Path to pretrained model (hub repo or local directory)",
|
help="Path to pretrained model (hub repo or local directory)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--policy-type",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
choices=list(POLICY_CLASSES.keys()),
|
|
||||||
help="Type of policy model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-dir",
|
"--output-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -276,28 +307,15 @@ def main():
|
|||||||
print("Extracting normalization statistics...")
|
print("Extracting normalization statistics...")
|
||||||
stats = extract_normalization_stats(state_dict)
|
stats = extract_normalization_stats(state_dict)
|
||||||
|
|
||||||
if not stats:
|
print(f"Found normalization statistics for: {list(stats.keys())}")
|
||||||
print("Warning: No normalization statistics found in model. The model might already be migrated.")
|
|
||||||
else:
|
|
||||||
print(f"Found normalization statistics for: {list(stats.keys())}")
|
|
||||||
|
|
||||||
# Detect features and normalization modes
|
# Detect input features and normalization modes
|
||||||
print("Detecting features and normalization modes...")
|
print("Detecting features and normalization modes...")
|
||||||
features, norm_map = detect_features_and_norm_modes(config, stats)
|
features, norm_map = detect_features_and_norm_modes(config, stats)
|
||||||
|
|
||||||
print(f"Detected features: {list(features.keys())}")
|
print(f"Detected features: {list(features.keys())}")
|
||||||
print(f"Normalization modes: {norm_map}")
|
print(f"Normalization modes: {norm_map}")
|
||||||
|
|
||||||
# Create NormalizerProcessor
|
|
||||||
print("Creating NormalizerProcessor...")
|
|
||||||
if stats:
|
|
||||||
processor = RobotProcessor(
|
|
||||||
[NormalizerProcessor(features, norm_map, stats)], name=f"{args.policy_type}_normalizer"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# No normalization needed
|
|
||||||
processor = RobotProcessor([], name=f"{args.policy_type}_normalizer")
|
|
||||||
|
|
||||||
# Remove normalization layers from state_dict
|
# Remove normalization layers from state_dict
|
||||||
print("Removing normalization layers from model...")
|
print("Removing normalization layers from model...")
|
||||||
new_state_dict = remove_normalization_layers(state_dict)
|
new_state_dict = remove_normalization_layers(state_dict)
|
||||||
@@ -317,20 +335,82 @@ def main():
|
|||||||
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Save migrated model
|
|
||||||
print(f"Saving migrated model to {output_dir}...")
|
|
||||||
save_safetensors(new_state_dict, output_dir / "model.safetensors")
|
|
||||||
|
|
||||||
# Clean up config - remove normalization_mapping field
|
# Clean up config - remove normalization_mapping field
|
||||||
cleaned_config = dict(config)
|
cleaned_config = dict(config)
|
||||||
if "normalization_mapping" in cleaned_config:
|
if "normalization_mapping" in cleaned_config:
|
||||||
print("Removing 'normalization_mapping' field from config")
|
print("Removing 'normalization_mapping' field from config")
|
||||||
del cleaned_config["normalization_mapping"]
|
del cleaned_config["normalization_mapping"]
|
||||||
|
policy_type = deepcopy(cleaned_config["type"])
|
||||||
|
|
||||||
# Save cleaned config
|
del cleaned_config["type"]
|
||||||
with open(output_dir / "config.json", "w") as f:
|
|
||||||
json.dump(cleaned_config, f, indent=2)
|
|
||||||
|
|
||||||
|
# Instantiate the policy model with cleaned config and load the cleaned state dict
|
||||||
|
print(f"Instantiating {policy_type} policy model...")
|
||||||
|
policy_class_path = POLICY_CLASSES[policy_type]
|
||||||
|
module_path, class_name = policy_class_path.rsplit(".", 1)
|
||||||
|
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
policy_class = getattr(module, class_name)
|
||||||
|
|
||||||
|
# Create config class instance
|
||||||
|
config_module_path = module_path.replace("modeling", "configuration")
|
||||||
|
config_module = importlib.import_module(config_module_path)
|
||||||
|
# Handle special cases for config class names
|
||||||
|
config_class_names = {
|
||||||
|
"act": "ACTConfig",
|
||||||
|
"diffusion": "DiffusionConfig",
|
||||||
|
"pi0": "PI0Config",
|
||||||
|
"pi0fast": "PI0FASTConfig",
|
||||||
|
"smolvla": "SmolVLAConfig",
|
||||||
|
"tdmpc": "TDMPCConfig",
|
||||||
|
"vqbet": "VQBeTConfig",
|
||||||
|
"sac": "SACConfig",
|
||||||
|
"classifier": "ClassifierConfig",
|
||||||
|
}
|
||||||
|
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
|
||||||
|
config_class = getattr(config_module, config_class_name)
|
||||||
|
|
||||||
|
# Convert input_features and output_features to PolicyFeature objects - these are mandatory
|
||||||
|
if "input_features" not in cleaned_config:
|
||||||
|
raise ValueError("Missing mandatory 'input_features' in config")
|
||||||
|
if "output_features" not in cleaned_config:
|
||||||
|
raise ValueError("Missing mandatory 'output_features' in config")
|
||||||
|
|
||||||
|
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
|
||||||
|
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
|
||||||
|
|
||||||
|
# Create config instance from cleaned config dict
|
||||||
|
policy_config = config_class(**cleaned_config)
|
||||||
|
|
||||||
|
# Create policy instance - some policies expect dataset_stats
|
||||||
|
policy = policy_class(policy_config)
|
||||||
|
|
||||||
|
# Load the cleaned state dict
|
||||||
|
policy.load_state_dict(new_state_dict, strict=True)
|
||||||
|
print("Successfully loaded cleaned state dict into policy model")
|
||||||
|
|
||||||
|
# Now create preprocessor and postprocessor with cleaned_config available
|
||||||
|
print("Creating preprocessor and postprocessor...")
|
||||||
|
# The pattern from existing processor factories:
|
||||||
|
# - Preprocessor has two NormalizerProcessors: one for input_features, one for output_features
|
||||||
|
# - Postprocessor has one UnnormalizerProcessor for output_features only
|
||||||
|
|
||||||
|
# Get features from cleaned_config (now they're PolicyFeature objects)
|
||||||
|
input_features = cleaned_config.get("input_features", {})
|
||||||
|
output_features = cleaned_config.get("output_features", {})
|
||||||
|
|
||||||
|
# Create preprocessor with two normalizers (following the pattern from processor factories)
|
||||||
|
preprocessor_steps = [
|
||||||
|
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
|
||||||
|
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||||
|
]
|
||||||
|
preprocessor = RobotProcessor(preprocessor_steps, name=f"{policy_type}_preprocessor")
|
||||||
|
|
||||||
|
# Create postprocessor with unnormalizer for outputs only
|
||||||
|
postprocessor_steps = [
|
||||||
|
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||||
|
]
|
||||||
|
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
|
||||||
# Determine hub repo ID if pushing to hub
|
# Determine hub repo ID if pushing to hub
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
if args.hub_repo_id:
|
if args.hub_repo_id:
|
||||||
@@ -344,50 +424,27 @@ def main():
|
|||||||
else:
|
else:
|
||||||
hub_repo_id = None
|
hub_repo_id = None
|
||||||
|
|
||||||
# Save processor (and optionally push to hub)
|
# Save model using the policy's save_pretrained method
|
||||||
processor_dir = output_dir / "processor"
|
print(f"Saving model to {output_dir}...")
|
||||||
processor.save_pretrained(
|
policy.save_pretrained(
|
||||||
processor_dir,
|
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
|
||||||
repo_id=hub_repo_id,
|
|
||||||
push_to_hub=args.push_to_hub,
|
|
||||||
private=args.private,
|
|
||||||
commit_message=f"Upload {args.policy_type} normalization processor",
|
|
||||||
)
|
)
|
||||||
print(f"Saved processor to {processor_dir}")
|
|
||||||
|
|
||||||
# If pushing to hub, also upload the model files
|
# Save preprocessor and postprocessor to root directory
|
||||||
|
print(f"Saving preprocessor to {output_dir}...")
|
||||||
|
preprocessor.save_pretrained(output_dir)
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
print(f"Pushing to hub repository: {hub_repo_id}")
|
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||||
|
|
||||||
# For the model, we still need to use API since it's just safetensors + config
|
print(f"Saving postprocessor to {output_dir}...")
|
||||||
api = HfApi()
|
postprocessor.save_pretrained(output_dir)
|
||||||
api.create_repo(repo_id=hub_repo_id, repo_type="model", private=args.private, exist_ok=True)
|
if args.push_to_hub:
|
||||||
|
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||||
# Upload model files
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=str(output_dir / "model.safetensors"),
|
|
||||||
path_in_repo="model.safetensors",
|
|
||||||
repo_id=hub_repo_id,
|
|
||||||
repo_type="model",
|
|
||||||
commit_message=f"Upload {args.policy_type} model weights without normalization",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=str(output_dir / "config.json"),
|
|
||||||
path_in_repo="config.json",
|
|
||||||
repo_id=hub_repo_id,
|
|
||||||
repo_type="model",
|
|
||||||
commit_message=f"Upload {args.policy_type} model config",
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
|
|
||||||
|
|
||||||
print("\nMigration complete!")
|
print("\nMigration complete!")
|
||||||
print(f"Migrated model saved to: {output_dir}")
|
print(f"Migrated model saved to: {output_dir}")
|
||||||
if processor.steps:
|
if args.push_to_hub:
|
||||||
print("Normalization processor created with statistics")
|
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
|
||||||
else:
|
|
||||||
print("No normalization processor needed (model had no normalization layers)")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user