diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index ba5ccaba7..e10c88bf4 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -37,13 +37,19 @@ Usage: --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \ --policy-type act \ --push-to-hub + +Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config` +factory functions from `lerobot.policies.factory` to create processors and configurations, +ensuring consistency with the current codebase. + +The script extracts normalization statistics from the old model's state_dict, creates clean +processor pipelines using the factory functions, and saves a migrated model that is compatible +with the new PolicyProcessorPipeline architecture. """ import argparse -import importlib import json import os -from copy import deepcopy from pathlib import Path from typing import Any @@ -52,25 +58,7 @@ from huggingface_hub import hf_hub_download from safetensors.torch import load_file as load_safetensors from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature - -from .batch_processor import AddBatchDimensionProcessorStep -from .device_processor import DeviceProcessorStep -from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep -from .pipeline import PolicyProcessorPipeline -from .rename_processor import RenameObservationsProcessorStep - -# Policy type to class mapping -POLICY_CLASSES = { - "act": "lerobot.policies.act.modeling_act.ACTPolicy", - "diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy", - "pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy", - "pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy", - "smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy", - "tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy", - "vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy", - "sac": "lerobot.policies.sac.modeling_sac.SACPolicy", - "classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy", -} +from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: @@ -158,25 +146,36 @@ def detect_features_and_norm_modes( if "normalization_mapping" in config: print(f"Found normalization_mapping in config: {config['normalization_mapping']}") # Extract normalization modes from config - for feature_name, mode_str in config["normalization_mapping"].items(): - # Convert string to NormalizationMode enum - if mode_str == "mean_std": - mode = NormalizationMode.MEAN_STD - elif mode_str == "min_max": - mode = NormalizationMode.MIN_MAX - else: - print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'") + for feature_type_str, mode_str in config["normalization_mapping"].items(): + # Convert string to FeatureType enum + try: + if feature_type_str == "VISUAL": + feature_type = FeatureType.VISUAL + elif feature_type_str == "STATE": + feature_type = FeatureType.STATE + elif feature_type_str == "ACTION": + feature_type = FeatureType.ACTION + else: + print(f"Warning: Unknown feature type '{feature_type_str}', skipping") + continue + except (AttributeError, ValueError): + print(f"Warning: Could not parse feature type '{feature_type_str}', skipping") continue - # Determine feature type from feature name - if "image" in feature_name or "visual" in feature_name: - feature_type = FeatureType.VISUAL - elif "state" in feature_name: - feature_type = FeatureType.STATE - elif "action" in feature_name: - feature_type = FeatureType.ACTION - else: - feature_type = FeatureType.STATE + # Convert string to NormalizationMode enum + try: + if mode_str == "MEAN_STD": + mode = NormalizationMode.MEAN_STD + elif mode_str == "MIN_MAX": + mode = NormalizationMode.MIN_MAX + else: + print( + f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'" + ) + continue + except (AttributeError, ValueError): + print(f"Warning: Could not parse normalization mode '{mode_str}', skipping") + continue norm_modes[feature_type] = mode @@ -303,7 +302,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[ # Get shape from feature dict shape = feature_dict.get("shape", feature_dict.get("dim")) - shape = (shape,) if isinstance(shape, int) else tuple(shape) + shape = (shape,) if isinstance(shape, int) else tuple(shape) if shape is not None else () converted_features[key] = PolicyFeature(feature_type, shape) @@ -311,7 +310,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[ def load_model_from_hub( - repo_id: str, revision: str = None + repo_id: str, revision: str | None = None ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: """ Downloads and loads a model's state_dict and configs from the Hugging Face Hub. @@ -368,6 +367,12 @@ def main(): ) parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load") parser.add_argument("--private", action="store_true", help="Make the hub repository private") + parser.add_argument( + "--policy-type", + type=str, + required=True, + help="Policy type (act, diffusion, pi0, pi0fast, smolvla, tdmpc, vqbet, sac, classifier)", + ) args = parser.parse_args() @@ -416,91 +421,51 @@ def main(): output_dir.mkdir(parents=True, exist_ok=True) - # Clean up config - remove normalization_mapping field + # Clean up config - remove fields that shouldn't be passed to config constructor cleaned_config = dict(config) - if "normalization_mapping" in cleaned_config: - print("Removing 'normalization_mapping' field from config") - del cleaned_config["normalization_mapping"] - policy_type = deepcopy(cleaned_config["type"]) - del cleaned_config["type"] + # Remove fields that are not part of the config class constructors + fields_to_remove = ["normalization_mapping", "type"] + for field in fields_to_remove: + if field in cleaned_config: + print(f"Removing '{field}' field from config") + del cleaned_config[field] - # Instantiate the policy model with cleaned config and load the cleaned state dict - print(f"Instantiating {policy_type} policy model...") - policy_class_path = POLICY_CLASSES[policy_type] - module_path, class_name = policy_class_path.rsplit(".", 1) + # Use the policy type from command line argument + policy_type = args.policy_type - module = importlib.import_module(module_path) - policy_class = getattr(module, class_name) + # Convert input_features and output_features to PolicyFeature objects if they exist + if "input_features" in cleaned_config: + cleaned_config["input_features"] = convert_features_to_policy_features( + cleaned_config["input_features"] + ) + if "output_features" in cleaned_config: + cleaned_config["output_features"] = convert_features_to_policy_features( + cleaned_config["output_features"] + ) - # Create config class instance - config_module_path = module_path.replace("modeling", "configuration") - config_module = importlib.import_module(config_module_path) - # Handle special cases for config class names - config_class_names = { - "act": "ACTConfig", - "diffusion": "DiffusionConfig", - "pi0": "PI0Config", - "pi0fast": "PI0FASTConfig", - "smolvla": "SmolVLAConfig", - "tdmpc": "TDMPCConfig", - "vqbet": "VQBeTConfig", - "sac": "SACConfig", - "classifier": "ClassifierConfig", - } - config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config") - config_class = getattr(config_module, config_class_name) + # Add normalization mapping to config + cleaned_config["normalization_mapping"] = norm_map - # Convert input_features and output_features to PolicyFeature objects - these are mandatory - if "input_features" not in cleaned_config: - raise ValueError("Missing mandatory 'input_features' in config") - if "output_features" not in cleaned_config: - raise ValueError("Missing mandatory 'output_features' in config") + # Create policy configuration using the factory + print(f"Creating {policy_type} policy configuration...") + policy_config = make_policy_config(policy_type, **cleaned_config) - cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"]) - cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"]) - - # Create config instance from cleaned config dict - policy_config = config_class(**cleaned_config) - - # Create policy instance - some policies expect dataset_stats + # Create policy instance using the factory + print(f"Instantiating {policy_type} policy...") + policy_class = get_policy_class(policy_type) policy = policy_class(policy_config) # Load the cleaned state dict policy.load_state_dict(new_state_dict, strict=True) print("Successfully loaded cleaned state dict into policy model") - # Now create preprocessor and postprocessor with cleaned_config available - print("Creating preprocessor and postprocessor...") - # The pattern from existing processor factories: - # - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features - # - Postprocessor has one UnnormalizerProcessorStep for output_features only - - # Get features from cleaned_config (now they're PolicyFeature objects) - input_features = cleaned_config.get("input_features", {}) - output_features = cleaned_config.get("output_features", {}) - - # Create preprocessor with two normalizers (following the pattern from processor factories) - preprocessor_steps = [ - RenameObservationsProcessorStep(rename_map={}), - NormalizerProcessorStep( - features={**input_features, **output_features}, - norm_map=norm_map, - stats=stats, - ), - AddBatchDimensionProcessorStep(), - DeviceProcessorStep(device=policy_config.device), - ] - preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor") - - # Create postprocessor with unnormalizer for outputs only - postprocessor_steps = [ - DeviceProcessorStep(device="cpu"), - UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats), - ] - postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor") + # Create preprocessor and postprocessor using the factory + print("Creating preprocessor and postprocessor using make_pre_post_processors...") + preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats) # Determine hub repo ID if pushing to hub + hub_repo_id = None if args.push_to_hub: if args.hub_repo_id: hub_repo_id = args.hub_repo_id @@ -510,25 +475,24 @@ def main(): hub_repo_id = f"{args.pretrained_path}_migrated" else: raise ValueError("--hub-repo-id must be specified when pushing local model to hub") - else: - hub_repo_id = None # Save preprocessor and postprocessor to root directory print(f"Saving preprocessor to {output_dir}...") preprocessor.save_pretrained(output_dir) - if args.push_to_hub: + if args.push_to_hub and hub_repo_id: preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private) print(f"Saving postprocessor to {output_dir}...") postprocessor.save_pretrained(output_dir) - if args.push_to_hub: + if args.push_to_hub and hub_repo_id: postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private) # Save model using the policy's save_pretrained method print(f"Saving model to {output_dir}...") - policy.save_pretrained( - output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private - ) + if args.push_to_hub and hub_repo_id: + policy.save_pretrained(output_dir, push_to_hub=True, repo_id=hub_repo_id, private=args.private) + else: + policy.save_pretrained(output_dir) # Generate and save model card print("Generating model card...") @@ -549,7 +513,7 @@ def main(): card.save(str(output_dir / "README.md")) print(f"Model card saved to {output_dir / 'README.md'}") # Push model card to hub if requested - if args.push_to_hub: + if args.push_to_hub and hub_repo_id: from huggingface_hub import HfApi api = HfApi() @@ -564,7 +528,7 @@ def main(): print("\nMigration complete!") print(f"Migrated model saved to: {output_dir}") - if args.push_to_hub: + if args.push_to_hub and hub_repo_id: print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")