From b632490b4b0d7aa9c42dd7a4f06768585944ed19 Mon Sep 17 00:00:00 2001
From: Adil Zouitine
Date: Wed, 23 Jul 2025 09:26:10 +0200
Subject: [PATCH] feat(migrate): Enhance migration script to create
 preprocessor and postprocessor for policy models

- Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
- Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
- Improved error handling for missing mandatory configuration fields during model instantiation.
---
 .../processor/migrate_policy_normalization.py | 231 +++++++++++-------
 1 file changed, 144 insertions(+), 87 deletions(-)

diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py
index 6f296e5c5..f11744d5e 100644
--- a/src/lerobot/processor/migrate_policy_normalization.py
+++ b/src/lerobot/processor/migrate_policy_normalization.py
@@ -20,29 +20,32 @@ Generic script to migrate any policy model with normalization layers to the new
 This script:
 1. Loads an existing pretrained policy model
 2. Extracts normalization statistics from the model
-3. Creates a NormalizerProcessor with these statistics
+3. Creates both preprocessor and postprocessor:
+   - Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
+   - Postprocessor: unnormalizes outputs (actions) for inference
 4. Removes normalization layers from the model state_dict
-5. Saves the new model and processor
+5. Saves the new model and both processors
 
 Usage:
-    python scripts/migration/migrate_policy_normalization.py \
+    python src/lerobot/processor/migrate_policy_normalization.py \
         --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
-        --policy-type act \
         --push-to-hub
 """
 
 import argparse
+import importlib
 import json
 import os
+from copy import deepcopy
 from pathlib import Path
 from typing import Any
 
 import torch
-from huggingface_hub import HfApi, hf_hub_download
-from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file as load_safetensors
 
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.processor.normalize_processor import NormalizerProcessor
+from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
 from lerobot.processor.pipeline import RobotProcessor
 
 # Policy type to class mapping
@@ -63,32 +66,41 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
     """Extract normalization statistics from model state_dict."""
     stats = {}
 
-    # Common patterns for normalization layers
+    # Define patterns to match and their prefixes to remove
     normalization_patterns = [
-        ("normalize_inputs.buffer_", "unnormalize_outputs.buffer_"),  # Common pattern
-        ("normalize.", "unnormalize."),  # Alternative pattern
-        ("input_normalizer.", "output_normalizer."),  # SAC pattern
+        "normalize_inputs.buffer_",
+        "unnormalize_outputs.buffer_",
+        "normalize_targets.buffer_",
+        "normalize.",  # Must come after normalize_* patterns
+        "unnormalize.",  # Must come after unnormalize_* patterns
+        "input_normalizer.",
+        "output_normalizer.",
     ]
 
-    # Extract all normalization buffers
+    # Process each key in state_dict
     for key, tensor in state_dict.items():
-        for norm_prefix, _ in normalization_patterns:
-            if norm_prefix in key:
-                # Extract the feature name and stat type
-                parts = key.replace(norm_prefix, "").split(".")
+        # Try each pattern
+        for pattern in normalization_patterns:
+            if key.startswith(pattern):
+                # Extract the remaining part after the pattern
+                remaining = key[len(pattern) :]
+                parts = remaining.split(".")
+
+                # Need at least feature name and stat type
                 if len(parts) >= 2:
-                    # Handle keys like "buffer_observation_state.mean"
-                    feature_parts = parts[:-1]
+                    # Last part is the stat type (mean, std, min, max, etc.)
                     stat_type = parts[-1]
+                    # Everything else is the feature name
+                    feature_name = ".".join(parts[:-1]).replace("_", ".")
 
-                    # Reconstruct feature name (e.g., "observation.state")
-                    feature_name = ".".join(feature_parts).replace("_", ".")
-
+                    # Add to stats
                     if feature_name not in stats:
                         stats[feature_name] = {}
-
                     stats[feature_name][stat_type] = tensor.clone()
 
+                # Only process the first matching pattern
+                break
+
     return stats
 
 
@@ -194,6 +206,7 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
     remove_patterns = [
         "normalize_inputs.",
         "unnormalize_outputs.",
+        "normalize_targets.",  # Added pattern for target normalization
        "normalize.",
         "unnormalize.",
         "input_normalizer.",
@@ -209,6 +222,30 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
     return new_state_dict
 
 
+def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
+    """Convert features from old format to PolicyFeature objects."""
+    converted_features = {}
+
+    for key, feature_dict in features_dict.items():
+        # Determine feature type based on key
+        if "image" in key or "visual" in key:
+            feature_type = FeatureType.VISUAL
+        elif "state" in key:
+            feature_type = FeatureType.STATE
+        elif "action" in key:
+            feature_type = FeatureType.ACTION
+        else:
+            feature_type = FeatureType.STATE
+
+        # Get shape from feature dict
+        shape = feature_dict.get("shape", feature_dict.get("dim"))
+        shape = (shape,) if isinstance(shape, int) else tuple(shape)
+
+        converted_features[key] = PolicyFeature(feature_type, shape)
+
+    return converted_features
+
+
 def load_model_from_hub(repo_id: str, revision: str = None) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
     """Load model state_dict and config from hub."""
     # Download files
@@ -236,13 +273,6 @@ def main():
         required=True,
         help="Path to pretrained model (hub repo or local directory)",
     )
-    parser.add_argument(
-        "--policy-type",
-        type=str,
-        required=True,
-        choices=list(POLICY_CLASSES.keys()),
-        help="Type of policy model",
-    )
     parser.add_argument(
         "--output-dir",
         type=str,
@@ -276,28 +306,15 @@ def main():
 
     print("Extracting normalization statistics...")
     stats = extract_normalization_stats(state_dict)
 
-    if not stats:
-        print("Warning: No normalization statistics found in model. The model might already be migrated.")
-    else:
-        print(f"Found normalization statistics for: {list(stats.keys())}")
+    print(f"Found normalization statistics for: {list(stats.keys())}")
 
-    # Detect features and normalization modes
+    # Detect input features and normalization modes
     print("Detecting features and normalization modes...")
     features, norm_map = detect_features_and_norm_modes(config, stats)
     print(f"Detected features: {list(features.keys())}")
     print(f"Normalization modes: {norm_map}")
 
-    # Create NormalizerProcessor
-    print("Creating NormalizerProcessor...")
-    if stats:
-        processor = RobotProcessor(
-            [NormalizerProcessor(features, norm_map, stats)], name=f"{args.policy_type}_normalizer"
-        )
-    else:
-        # No normalization needed
-        processor = RobotProcessor([], name=f"{args.policy_type}_normalizer")
-
     # Remove normalization layers from state_dict
     print("Removing normalization layers from model...")
     new_state_dict = remove_normalization_layers(state_dict)
@@ -317,20 +334,82 @@ def main():
 
     output_dir.mkdir(parents=True, exist_ok=True)
 
-    # Save migrated model
-    print(f"Saving migrated model to {output_dir}...")
-    save_safetensors(new_state_dict, output_dir / "model.safetensors")
-
     # Clean up config - remove normalization_mapping field
     cleaned_config = dict(config)
     if "normalization_mapping" in cleaned_config:
         print("Removing 'normalization_mapping' field from config")
         del cleaned_config["normalization_mapping"]
+    policy_type = deepcopy(cleaned_config["type"])
 
-    # Save cleaned config
-    with open(output_dir / "config.json", "w") as f:
-        json.dump(cleaned_config, f, indent=2)
+    del cleaned_config["type"]
+
+    # Instantiate the policy model with cleaned config and load the cleaned state dict
+    print(f"Instantiating {policy_type} policy model...")
+    policy_class_path = POLICY_CLASSES[policy_type]
+    module_path, class_name = policy_class_path.rsplit(".", 1)
+
+    module = importlib.import_module(module_path)
+    policy_class = getattr(module, class_name)
+
+    # Create config class instance
+    config_module_path = module_path.replace("modeling", "configuration")
+    config_module = importlib.import_module(config_module_path)
+    # Handle special cases for config class names
+    config_class_names = {
+        "act": "ACTConfig",
+        "diffusion": "DiffusionConfig",
+        "pi0": "PI0Config",
+        "pi0fast": "PI0FASTConfig",
+        "smolvla": "SmolVLAConfig",
+        "tdmpc": "TDMPCConfig",
+        "vqbet": "VQBeTConfig",
+        "sac": "SACConfig",
+        "classifier": "ClassifierConfig",
+    }
+    config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
+    config_class = getattr(config_module, config_class_name)
+
+    # Convert input_features and output_features to PolicyFeature objects - these are mandatory
+    if "input_features" not in cleaned_config:
+        raise ValueError("Missing mandatory 'input_features' in config")
+    if "output_features" not in cleaned_config:
+        raise ValueError("Missing mandatory 'output_features' in config")
+
+    cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
+    cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
+
+    # Create config instance from cleaned config dict
+    policy_config = config_class(**cleaned_config)
+
+    # Create policy instance - some policies expect dataset_stats
+    policy = policy_class(policy_config)
+
+    # Load the cleaned state dict
+    policy.load_state_dict(new_state_dict, strict=True)
+    print("Successfully loaded cleaned state dict into policy model")
+
+    # Now create preprocessor and postprocessor with cleaned_config available
+    print("Creating preprocessor and postprocessor...")
+    # The pattern from existing processor factories:
+    # - Preprocessor has two NormalizerProcessors: one for input_features, one for output_features
+    # - Postprocessor has one UnnormalizerProcessor for output_features only
+
+    # Get features from cleaned_config (now they're PolicyFeature objects)
+    input_features = cleaned_config.get("input_features", {})
+    output_features = cleaned_config.get("output_features", {})
+
+    # Create preprocessor with two normalizers (following the pattern from processor factories)
+    preprocessor_steps = [
+        NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
+        NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
+    ]
+    preprocessor = RobotProcessor(preprocessor_steps, name=f"{policy_type}_preprocessor")
+
+    # Create postprocessor with unnormalizer for outputs only
+    postprocessor_steps = [
+        UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
+    ]
+    postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
 
     # Determine hub repo ID if pushing to hub
     if args.push_to_hub:
         if args.hub_repo_id:
@@ -344,50 +423,27 @@ def main():
         else:
             hub_repo_id = None
 
-    # Save processor (and optionally push to hub)
-    processor_dir = output_dir / "processor"
-    processor.save_pretrained(
-        processor_dir,
-        repo_id=hub_repo_id,
-        push_to_hub=args.push_to_hub,
-        private=args.private,
-        commit_message=f"Upload {args.policy_type} normalization processor",
+    # Save model using the policy's save_pretrained method
+    print(f"Saving model to {output_dir}...")
+    policy.save_pretrained(
+        output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
     )
-    print(f"Saved processor to {processor_dir}")
 
-    # If pushing to hub, also upload the model files
+    # Save preprocessor and postprocessor to root directory
+    print(f"Saving preprocessor to {output_dir}...")
+    preprocessor.save_pretrained(output_dir)
     if args.push_to_hub:
-        print(f"Pushing to hub repository: {hub_repo_id}")
+        preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
 
-        # For the model, we still need to use API since it's just safetensors + config
-        api = HfApi()
-        api.create_repo(repo_id=hub_repo_id, repo_type="model", private=args.private, exist_ok=True)
-
-        # Upload model files
-        api.upload_file(
-            path_or_fileobj=str(output_dir / "model.safetensors"),
-            path_in_repo="model.safetensors",
-            repo_id=hub_repo_id,
-            repo_type="model",
-            commit_message=f"Upload {args.policy_type} model weights without normalization",
-        )
-
-        api.upload_file(
-            path_or_fileobj=str(output_dir / "config.json"),
-            path_in_repo="config.json",
-            repo_id=hub_repo_id,
-            repo_type="model",
-            commit_message=f"Upload {args.policy_type} model config",
-        )
-
-        print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
+    print(f"Saving postprocessor to {output_dir}...")
+    postprocessor.save_pretrained(output_dir)
+    if args.push_to_hub:
+        postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
 
     print("\nMigration complete!")
     print(f"Migrated model saved to: {output_dir}")
-    if processor.steps:
-        print("Normalization processor created with statistics")
-    else:
-        print("No normalization processor needed (model had no normalization layers)")
+    if args.push_to_hub:
+        print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
 
 
 if __name__ == "__main__":
     main()
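Editor's note: a condensed, standalone sanity check of the new prefix-based extraction added in this patch. The function body below mirrors the `extract_normalization_stats` hunk above (with `setdefault` replacing the explicit `if feature_name not in stats` check); only `torch` is needed, no lerobot install. The dummy buffer keys follow the `normalize_inputs.buffer_observation_state.mean` shape referenced in the removed comment; names, values, and dimensions are illustrative only.

```python
import torch

# Condensed copy of extract_normalization_stats() from the hunk above.
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    stats: dict[str, dict[str, torch.Tensor]] = {}
    normalization_patterns = [
        "normalize_inputs.buffer_",
        "unnormalize_outputs.buffer_",
        "normalize_targets.buffer_",
        "normalize.",  # must come after the normalize_* patterns
        "unnormalize.",  # must come after the unnormalize_* patterns
        "input_normalizer.",
        "output_normalizer.",
    ]
    for key, tensor in state_dict.items():
        for pattern in normalization_patterns:
            if key.startswith(pattern):
                parts = key[len(pattern) :].split(".")
                if len(parts) >= 2:
                    stat_type = parts[-1]  # e.g. "mean"
                    # Underscores in the buffer suffix become dots, e.g.
                    # "observation_state" -> "observation.state"
                    feature_name = ".".join(parts[:-1]).replace("_", ".")
                    stats.setdefault(feature_name, {})[stat_type] = tensor.clone()
                break  # only the first matching pattern is processed
    return stats

# Dummy state_dict shaped like a pre-migration checkpoint (values are placeholders).
sd = {
    "normalize_inputs.buffer_observation_state.mean": torch.zeros(14),
    "normalize_inputs.buffer_observation_state.std": torch.ones(14),
    "unnormalize_outputs.buffer_action.mean": torch.zeros(14),
}
print(extract_normalization_stats(sd))
# {'observation.state': {'mean': ..., 'std': ...}, 'action': {'mean': ...}}
```

Note the blanket `replace("_", ".")`: every underscore in the buffer suffix becomes a dot, which is what maps the old flat buffer names onto the new dotted feature keys consumed by `NormalizerProcessor` and `UnnormalizerProcessor`.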