refactor(processor): migrate policy normalization to use factory functions

- Updated the migration script to utilize `make_pre_post_processors` and `make_policy_config` from `lerobot.policies.factory`, enhancing consistency with the current codebase.
- Improved normalization statistics extraction and processor pipeline creation, ensuring compatibility with the new `PolicyProcessorPipeline` architecture.
- Cleaned up configuration handling by removing unnecessary fields and adding normalization mapping directly to the config.
- Enhanced type safety and readability by refining feature type and normalization mode handling.
This commit is contained in:
AdilZouitine
2025-09-11 18:14:13 +02:00
parent aeb70812c1
commit efde42d4a9
@@ -37,13 +37,19 @@ Usage:
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \ --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
--policy-type act \ --policy-type act \
--push-to-hub --push-to-hub
Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config`
factory functions from `lerobot.policies.factory` to create processors and configurations,
ensuring consistency with the current codebase.
The script extracts normalization statistics from the old model's state_dict, creates clean
processor pipelines using the factory functions, and saves a migrated model that is compatible
with the new PolicyProcessorPipeline architecture.
""" """
import argparse import argparse
import importlib
import json import json
import os import os
from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -52,25 +58,7 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors
from .batch_processor import AddBatchDimensionProcessorStep
from .device_processor import DeviceProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from .pipeline import PolicyProcessorPipeline
from .rename_processor import RenameObservationsProcessorStep
# Policy type to class mapping
POLICY_CLASSES = {
"act": "lerobot.policies.act.modeling_act.ACTPolicy",
"diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
"pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
"pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
"smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
"tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
"vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
"sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
"classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
}
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
@@ -158,25 +146,36 @@ def detect_features_and_norm_modes(
if "normalization_mapping" in config: if "normalization_mapping" in config:
print(f"Found normalization_mapping in config: {config['normalization_mapping']}") print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
# Extract normalization modes from config # Extract normalization modes from config
for feature_name, mode_str in config["normalization_mapping"].items(): for feature_type_str, mode_str in config["normalization_mapping"].items():
# Convert string to NormalizationMode enum # Convert string to FeatureType enum
if mode_str == "mean_std": try:
mode = NormalizationMode.MEAN_STD if feature_type_str == "VISUAL":
elif mode_str == "min_max": feature_type = FeatureType.VISUAL
mode = NormalizationMode.MIN_MAX elif feature_type_str == "STATE":
else: feature_type = FeatureType.STATE
print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'") elif feature_type_str == "ACTION":
feature_type = FeatureType.ACTION
else:
print(f"Warning: Unknown feature type '{feature_type_str}', skipping")
continue
except (AttributeError, ValueError):
print(f"Warning: Could not parse feature type '{feature_type_str}', skipping")
continue continue
# Determine feature type from feature name # Convert string to NormalizationMode enum
if "image" in feature_name or "visual" in feature_name: try:
feature_type = FeatureType.VISUAL if mode_str == "MEAN_STD":
elif "state" in feature_name: mode = NormalizationMode.MEAN_STD
feature_type = FeatureType.STATE elif mode_str == "MIN_MAX":
elif "action" in feature_name: mode = NormalizationMode.MIN_MAX
feature_type = FeatureType.ACTION else:
else: print(
feature_type = FeatureType.STATE f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'"
)
continue
except (AttributeError, ValueError):
print(f"Warning: Could not parse normalization mode '{mode_str}', skipping")
continue
norm_modes[feature_type] = mode norm_modes[feature_type] = mode
@@ -303,7 +302,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
# Get shape from feature dict # Get shape from feature dict
shape = feature_dict.get("shape", feature_dict.get("dim")) shape = feature_dict.get("shape", feature_dict.get("dim"))
shape = (shape,) if isinstance(shape, int) else tuple(shape) shape = (shape,) if isinstance(shape, int) else tuple(shape) if shape is not None else ()
converted_features[key] = PolicyFeature(feature_type, shape) converted_features[key] = PolicyFeature(feature_type, shape)
@@ -311,7 +310,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
def load_model_from_hub( def load_model_from_hub(
repo_id: str, revision: str = None repo_id: str, revision: str | None = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
""" """
Downloads and loads a model's state_dict and configs from the Hugging Face Hub. Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
@@ -368,6 +367,12 @@ def main():
) )
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load") parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
parser.add_argument("--private", action="store_true", help="Make the hub repository private") parser.add_argument("--private", action="store_true", help="Make the hub repository private")
parser.add_argument(
"--policy-type",
type=str,
required=True,
help="Policy type (act, diffusion, pi0, pi0fast, smolvla, tdmpc, vqbet, sac, classifier)",
)
args = parser.parse_args() args = parser.parse_args()
@@ -416,91 +421,51 @@ def main():
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
# Clean up config - remove normalization_mapping field # Clean up config - remove fields that shouldn't be passed to config constructor
cleaned_config = dict(config) cleaned_config = dict(config)
if "normalization_mapping" in cleaned_config:
print("Removing 'normalization_mapping' field from config")
del cleaned_config["normalization_mapping"]
policy_type = deepcopy(cleaned_config["type"])
del cleaned_config["type"] # Remove fields that are not part of the config class constructors
fields_to_remove = ["normalization_mapping", "type"]
for field in fields_to_remove:
if field in cleaned_config:
print(f"Removing '{field}' field from config")
del cleaned_config[field]
# Instantiate the policy model with cleaned config and load the cleaned state dict # Use the policy type from command line argument
print(f"Instantiating {policy_type} policy model...") policy_type = args.policy_type
policy_class_path = POLICY_CLASSES[policy_type]
module_path, class_name = policy_class_path.rsplit(".", 1)
module = importlib.import_module(module_path) # Convert input_features and output_features to PolicyFeature objects if they exist
policy_class = getattr(module, class_name) if "input_features" in cleaned_config:
cleaned_config["input_features"] = convert_features_to_policy_features(
cleaned_config["input_features"]
)
if "output_features" in cleaned_config:
cleaned_config["output_features"] = convert_features_to_policy_features(
cleaned_config["output_features"]
)
# Create config class instance # Add normalization mapping to config
config_module_path = module_path.replace("modeling", "configuration") cleaned_config["normalization_mapping"] = norm_map
config_module = importlib.import_module(config_module_path)
# Handle special cases for config class names
config_class_names = {
"act": "ACTConfig",
"diffusion": "DiffusionConfig",
"pi0": "PI0Config",
"pi0fast": "PI0FASTConfig",
"smolvla": "SmolVLAConfig",
"tdmpc": "TDMPCConfig",
"vqbet": "VQBeTConfig",
"sac": "SACConfig",
"classifier": "ClassifierConfig",
}
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
config_class = getattr(config_module, config_class_name)
# Convert input_features and output_features to PolicyFeature objects - these are mandatory # Create policy configuration using the factory
if "input_features" not in cleaned_config: print(f"Creating {policy_type} policy configuration...")
raise ValueError("Missing mandatory 'input_features' in config") policy_config = make_policy_config(policy_type, **cleaned_config)
if "output_features" not in cleaned_config:
raise ValueError("Missing mandatory 'output_features' in config")
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"]) # Create policy instance using the factory
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"]) print(f"Instantiating {policy_type} policy...")
policy_class = get_policy_class(policy_type)
# Create config instance from cleaned config dict
policy_config = config_class(**cleaned_config)
# Create policy instance - some policies expect dataset_stats
policy = policy_class(policy_config) policy = policy_class(policy_config)
# Load the cleaned state dict # Load the cleaned state dict
policy.load_state_dict(new_state_dict, strict=True) policy.load_state_dict(new_state_dict, strict=True)
print("Successfully loaded cleaned state dict into policy model") print("Successfully loaded cleaned state dict into policy model")
# Now create preprocessor and postprocessor with cleaned_config available # Create preprocessor and postprocessor using the factory
print("Creating preprocessor and postprocessor...") print("Creating preprocessor and postprocessor using make_pre_post_processors...")
# The pattern from existing processor factories: preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
# - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
# - Postprocessor has one UnnormalizerProcessorStep for output_features only
# Get features from cleaned_config (now they're PolicyFeature objects)
input_features = cleaned_config.get("input_features", {})
output_features = cleaned_config.get("output_features", {})
# Create preprocessor with two normalizers (following the pattern from processor factories)
preprocessor_steps = [
RenameObservationsProcessorStep(rename_map={}),
NormalizerProcessorStep(
features={**input_features, **output_features},
norm_map=norm_map,
stats=stats,
),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=policy_config.device),
]
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")
# Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
]
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")
# Determine hub repo ID if pushing to hub # Determine hub repo ID if pushing to hub
hub_repo_id = None
if args.push_to_hub: if args.push_to_hub:
if args.hub_repo_id: if args.hub_repo_id:
hub_repo_id = args.hub_repo_id hub_repo_id = args.hub_repo_id
@@ -510,25 +475,24 @@ def main():
hub_repo_id = f"{args.pretrained_path}_migrated" hub_repo_id = f"{args.pretrained_path}_migrated"
else: else:
raise ValueError("--hub-repo-id must be specified when pushing local model to hub") raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
else:
hub_repo_id = None
# Save preprocessor and postprocessor to root directory # Save preprocessor and postprocessor to root directory
print(f"Saving preprocessor to {output_dir}...") print(f"Saving preprocessor to {output_dir}...")
preprocessor.save_pretrained(output_dir) preprocessor.save_pretrained(output_dir)
if args.push_to_hub: if args.push_to_hub and hub_repo_id:
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private) preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
print(f"Saving postprocessor to {output_dir}...") print(f"Saving postprocessor to {output_dir}...")
postprocessor.save_pretrained(output_dir) postprocessor.save_pretrained(output_dir)
if args.push_to_hub: if args.push_to_hub and hub_repo_id:
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private) postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
# Save model using the policy's save_pretrained method # Save model using the policy's save_pretrained method
print(f"Saving model to {output_dir}...") print(f"Saving model to {output_dir}...")
policy.save_pretrained( if args.push_to_hub and hub_repo_id:
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private policy.save_pretrained(output_dir, push_to_hub=True, repo_id=hub_repo_id, private=args.private)
) else:
policy.save_pretrained(output_dir)
# Generate and save model card # Generate and save model card
print("Generating model card...") print("Generating model card...")
@@ -549,7 +513,7 @@ def main():
card.save(str(output_dir / "README.md")) card.save(str(output_dir / "README.md"))
print(f"Model card saved to {output_dir / 'README.md'}") print(f"Model card saved to {output_dir / 'README.md'}")
# Push model card to hub if requested # Push model card to hub if requested
if args.push_to_hub: if args.push_to_hub and hub_repo_id:
from huggingface_hub import HfApi from huggingface_hub import HfApi
api = HfApi() api = HfApi()
@@ -564,7 +528,7 @@ def main():
print("\nMigration complete!") print("\nMigration complete!")
print(f"Migrated model saved to: {output_dir}") print(f"Migrated model saved to: {output_dir}")
if args.push_to_hub: if args.push_to_hub and hub_repo_id:
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}") print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")