refactor(processor): migrate policy normalization to use factory functions

- Updated the migration script to use `make_pre_post_processors` and `make_policy_config` from `lerobot.policies.factory`, keeping it consistent with the current codebase (see the sketch below).
- Improved normalization statistics extraction and processor pipeline creation, ensuring compatibility with the new `PolicyProcessorPipeline` architecture.
- Cleaned up configuration handling by removing unnecessary fields and adding normalization mapping directly to the config.
- Enhanced type safety and readability by refining feature type and normalization mode handling.
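
In outline, the script now leans on the factory layer instead of maintaining its own class registries. A minimal sketch of the new flow (`policy_type`, `cleaned_config`, `stats`, and the factory calls all appear in the diff below; the placeholder values here are illustrative only):

from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors

policy_type = "act"   # now supplied via the new --policy-type CLI argument
cleaned_config = {}   # placeholder: old config.json minus "type"/"normalization_mapping"
stats = {}            # placeholder: normalization statistics extracted from the old state_dict

policy_config = make_policy_config(policy_type, **cleaned_config)
policy = get_policy_class(policy_type)(policy_config)
preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)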
commit efde42d4a9 (parent aeb70812c1)
Author: AdilZouitine
Date:   2025-09-11 18:14:13 +02:00
@@ -37,13 +37,19 @@ Usage:
     --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
     --policy-type act \
     --push-to-hub

+Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config`
+factory functions from `lerobot.policies.factory` to create processors and configurations,
+ensuring consistency with the current codebase.
+
+The script extracts normalization statistics from the old model's state_dict, creates clean
+processor pipelines using the factory functions, and saves a migrated model that is compatible
+with the new PolicyProcessorPipeline architecture.
"""

import argparse
import importlib
import json
import os
from copy import deepcopy
from pathlib import Path
from typing import Any
@@ -52,25 +58,7 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from .batch_processor import AddBatchDimensionProcessorStep
-from .device_processor import DeviceProcessorStep
-from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
-from .pipeline import PolicyProcessorPipeline
-from .rename_processor import RenameObservationsProcessorStep
-
-# Policy type to class mapping
-POLICY_CLASSES = {
-    "act": "lerobot.policies.act.modeling_act.ACTPolicy",
-    "diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
-    "pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
-    "pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
-    "smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
-    "tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
-    "vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
-    "sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
-    "classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
-}
+from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors


def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
@@ -158,25 +146,36 @@ def detect_features_and_norm_modes(
if "normalization_mapping" in config:
print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
# Extract normalization modes from config
for feature_name, mode_str in config["normalization_mapping"].items():
# Convert string to NormalizationMode enum
if mode_str == "mean_std":
mode = NormalizationMode.MEAN_STD
elif mode_str == "min_max":
mode = NormalizationMode.MIN_MAX
else:
print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
for feature_type_str, mode_str in config["normalization_mapping"].items():
# Convert string to FeatureType enum
try:
if feature_type_str == "VISUAL":
feature_type = FeatureType.VISUAL
elif feature_type_str == "STATE":
feature_type = FeatureType.STATE
elif feature_type_str == "ACTION":
feature_type = FeatureType.ACTION
else:
print(f"Warning: Unknown feature type '{feature_type_str}', skipping")
continue
except (AttributeError, ValueError):
print(f"Warning: Could not parse feature type '{feature_type_str}', skipping")
continue
# Determine feature type from feature name
if "image" in feature_name or "visual" in feature_name:
feature_type = FeatureType.VISUAL
elif "state" in feature_name:
feature_type = FeatureType.STATE
elif "action" in feature_name:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
# Convert string to NormalizationMode enum
try:
if mode_str == "MEAN_STD":
mode = NormalizationMode.MEAN_STD
elif mode_str == "MIN_MAX":
mode = NormalizationMode.MIN_MAX
else:
print(
f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'"
)
continue
except (AttributeError, ValueError):
print(f"Warning: Could not parse normalization mode '{mode_str}', skipping")
continue
norm_modes[feature_type] = mode
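
Aside: the try/except blocks above wrap plain string comparisons, so `AttributeError`/`ValueError` cannot actually be raised there. A hedged alternative sketch (not part of this commit; `parse_mapping_entry` is a hypothetical helper) using enum name lookup, which raises `KeyError` for unknown names:

from lerobot.configs.types import FeatureType, NormalizationMode

def parse_mapping_entry(feature_type_str: str, mode_str: str) -> tuple[FeatureType, NormalizationMode] | None:
    try:
        # Enum name lookup: "VISUAL" -> FeatureType.VISUAL, "MEAN_STD" -> NormalizationMode.MEAN_STD
        return FeatureType[feature_type_str], NormalizationMode[mode_str]
    except KeyError as unknown:
        print(f"Warning: skipping unknown normalization_mapping entry: {unknown}")
        return None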
@@ -303,7 +302,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
        # Get shape from feature dict
        shape = feature_dict.get("shape", feature_dict.get("dim"))
-        shape = (shape,) if isinstance(shape, int) else tuple(shape)
+        shape = (shape,) if isinstance(shape, int) else tuple(shape) if shape is not None else ()

        converted_features[key] = PolicyFeature(feature_type, shape)
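
The one-line change above guards the shape coercion against a missing `shape`/`dim` entry. A standalone sketch of the three cases the new expression handles (toy function, not from the script):

def coerce_shape(shape) -> tuple:
    # int -> 1-tuple, sequence -> tuple, absent -> empty tuple
    return (shape,) if isinstance(shape, int) else tuple(shape) if shape is not None else ()

assert coerce_shape(7) == (7,)
assert coerce_shape([3, 224, 224]) == (3, 224, 224)
assert coerce_shape(None) == ()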
@@ -311,7 +310,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
def load_model_from_hub(
-    repo_id: str, revision: str = None
+    repo_id: str, revision: str | None = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
    """
    Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
@@ -368,6 +367,12 @@ def main():
    )
    parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
    parser.add_argument("--private", action="store_true", help="Make the hub repository private")
+    parser.add_argument(
+        "--policy-type",
+        type=str,
+        required=True,
+        help="Policy type (act, diffusion, pi0, pi0fast, smolvla, tdmpc, vqbet, sac, classifier)",
+    )

    args = parser.parse_args()
@@ -416,91 +421,51 @@ def main():
    output_dir.mkdir(parents=True, exist_ok=True)

-    # Clean up config - remove normalization_mapping field
+    # Clean up config - remove fields that shouldn't be passed to config constructor
    cleaned_config = dict(config)
-    if "normalization_mapping" in cleaned_config:
-        print("Removing 'normalization_mapping' field from config")
-        del cleaned_config["normalization_mapping"]
-
-    policy_type = deepcopy(cleaned_config["type"])
-    del cleaned_config["type"]
+    # Remove fields that are not part of the config class constructors
+    fields_to_remove = ["normalization_mapping", "type"]
+    for field in fields_to_remove:
+        if field in cleaned_config:
+            print(f"Removing '{field}' field from config")
+            del cleaned_config[field]

-    # Instantiate the policy model with cleaned config and load the cleaned state dict
-    print(f"Instantiating {policy_type} policy model...")
-    policy_class_path = POLICY_CLASSES[policy_type]
-    module_path, class_name = policy_class_path.rsplit(".", 1)
+    # Use the policy type from command line argument
+    policy_type = args.policy_type

-    module = importlib.import_module(module_path)
-    policy_class = getattr(module, class_name)
-
-    # Convert input_features and output_features to PolicyFeature objects if they exist
-    if "input_features" in cleaned_config:
-        cleaned_config["input_features"] = convert_features_to_policy_features(
-            cleaned_config["input_features"]
-        )
-    if "output_features" in cleaned_config:
-        cleaned_config["output_features"] = convert_features_to_policy_features(
-            cleaned_config["output_features"]
-        )
-
-    # Create config class instance
-    config_module_path = module_path.replace("modeling", "configuration")
-    config_module = importlib.import_module(config_module_path)
-
-    # Handle special cases for config class names
-    config_class_names = {
-        "act": "ACTConfig",
-        "diffusion": "DiffusionConfig",
-        "pi0": "PI0Config",
-        "pi0fast": "PI0FASTConfig",
-        "smolvla": "SmolVLAConfig",
-        "tdmpc": "TDMPCConfig",
-        "vqbet": "VQBeTConfig",
-        "sac": "SACConfig",
-        "classifier": "ClassifierConfig",
-    }
-    config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
-    config_class = getattr(config_module, config_class_name)
+    # Add normalization mapping to config
+    cleaned_config["normalization_mapping"] = norm_map

+    # Convert input_features and output_features to PolicyFeature objects - these are mandatory
+    if "input_features" not in cleaned_config:
+        raise ValueError("Missing mandatory 'input_features' in config")
+    if "output_features" not in cleaned_config:
+        raise ValueError("Missing mandatory 'output_features' in config")

+    # Create policy configuration using the factory
+    print(f"Creating {policy_type} policy configuration...")
+    policy_config = make_policy_config(policy_type, **cleaned_config)

+    cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
+    cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])

-    # Create config instance from cleaned config dict
-    policy_config = config_class(**cleaned_config)

-    # Create policy instance - some policies expect dataset_stats
+    # Create policy instance using the factory
+    print(f"Instantiating {policy_type} policy...")
+    policy_class = get_policy_class(policy_type)
    policy = policy_class(policy_config)

    # Load the cleaned state dict
    policy.load_state_dict(new_state_dict, strict=True)
    print("Successfully loaded cleaned state dict into policy model")

-    # Now create preprocessor and postprocessor with cleaned_config available
-    print("Creating preprocessor and postprocessor...")
-
-    # The pattern from existing processor factories:
-    # - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
-    # - Postprocessor has one UnnormalizerProcessorStep for output_features only
-
-    # Get features from cleaned_config (now they're PolicyFeature objects)
-    input_features = cleaned_config.get("input_features", {})
-    output_features = cleaned_config.get("output_features", {})
-
-    # Create preprocessor with two normalizers (following the pattern from processor factories)
-    preprocessor_steps = [
-        RenameObservationsProcessorStep(rename_map={}),
-        NormalizerProcessorStep(
-            features={**input_features, **output_features},
-            norm_map=norm_map,
-            stats=stats,
-        ),
-        AddBatchDimensionProcessorStep(),
-        DeviceProcessorStep(device=policy_config.device),
-    ]
-    preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")
-
-    # Create postprocessor with unnormalizer for outputs only
-    postprocessor_steps = [
-        DeviceProcessorStep(device="cpu"),
-        UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
-    ]
-    postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")
+    # Create preprocessor and postprocessor using the factory
+    print("Creating preprocessor and postprocessor using make_pre_post_processors...")
+    preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)

    # Determine hub repo ID if pushing to hub
+    hub_repo_id = None
    if args.push_to_hub:
        if args.hub_repo_id:
            hub_repo_id = args.hub_repo_id
@@ -510,25 +475,24 @@ def main():
hub_repo_id = f"{args.pretrained_path}_migrated"
else:
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
else:
hub_repo_id = None
# Save preprocessor and postprocessor to root directory
print(f"Saving preprocessor to {output_dir}...")
preprocessor.save_pretrained(output_dir)
if args.push_to_hub:
if args.push_to_hub and hub_repo_id:
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
print(f"Saving postprocessor to {output_dir}...")
postprocessor.save_pretrained(output_dir)
if args.push_to_hub:
if args.push_to_hub and hub_repo_id:
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
# Save model using the policy's save_pretrained method
print(f"Saving model to {output_dir}...")
policy.save_pretrained(
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
)
if args.push_to_hub and hub_repo_id:
policy.save_pretrained(output_dir, push_to_hub=True, repo_id=hub_repo_id, private=args.private)
else:
policy.save_pretrained(output_dir)
# Generate and save model card
print("Generating model card...")
@@ -549,7 +513,7 @@ def main():
    card.save(str(output_dir / "README.md"))
    print(f"Model card saved to {output_dir / 'README.md'}")

    # Push model card to hub if requested
-    if args.push_to_hub:
+    if args.push_to_hub and hub_repo_id:
        from huggingface_hub import HfApi

        api = HfApi()
@@ -564,7 +528,7 @@ def main():
print("\nMigration complete!")
print(f"Migrated model saved to: {output_dir}")
if args.push_to_hub:
if args.push_to_hub and hub_repo_id:
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")