mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
refactor(processor): migrate policy normalization to use factory functions
- Updated the migration script to utilize `make_pre_post_processors` and `make_policy_config` from `lerobot.policies.factory`, enhancing consistency with the current codebase.
- Improved normalization statistics extraction and processor pipeline creation, ensuring compatibility with the new `PolicyProcessorPipeline` architecture.
- Cleaned up configuration handling by removing unnecessary fields and adding normalization mapping directly to the config.
- Enhanced type safety and readability by refining feature type and normalization mode handling.
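For orientation, here is the new factory-based flow condensed from the diff below into a small sketch. The call signatures are taken verbatim from the changed lines; the `policy_type`, `cleaned_config`, and `stats` values are placeholders, not real migration data.

```python
from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors

policy_type = "act"      # from --policy-type (placeholder)
cleaned_config = {}      # legacy config dict with non-constructor fields removed (placeholder)
stats = {}               # normalization stats extracted from the old state_dict (placeholder)

# Build the policy config from the cleaned legacy config dict.
policy_config = make_policy_config(policy_type, **cleaned_config)

# Instantiate the policy via the factory instead of a hard-coded class mapping.
policy_class = get_policy_class(policy_type)
policy = policy_class(policy_config)

# Create matching pre/post processors from the config and the extracted stats.
preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
```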
@@ -37,13 +37,19 @@ Usage:
    --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
    --policy-type act \
    --push-to-hub

Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config`
factory functions from `lerobot.policies.factory` to create processors and configurations,
ensuring consistency with the current codebase.

The script extracts normalization statistics from the old model's state_dict, creates clean
processor pipelines using the factory functions, and saves a migrated model that is compatible
with the new PolicyProcessorPipeline architecture.
"""

import argparse
import importlib
import json
import os
from copy import deepcopy
from pathlib import Path
from typing import Any

@@ -52,25 +58,7 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

from .batch_processor import AddBatchDimensionProcessorStep
from .device_processor import DeviceProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from .pipeline import PolicyProcessorPipeline
from .rename_processor import RenameObservationsProcessorStep

# Policy type to class mapping
POLICY_CLASSES = {
    "act": "lerobot.policies.act.modeling_act.ACTPolicy",
    "diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
    "pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
    "pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
    "smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
    "tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
    "vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
    "sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
    "classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
}
from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors


def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
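The body of `extract_normalization_stats` is not part of this hunk; only its signature is visible above. As a rough illustration of the kind of grouping that signature implies, here is a hypothetical helper. The buffer key layout (for example `normalize_inputs.observation.state.mean`) is an assumption for illustration, not the checkpoint format used by the script.

```python
import torch


def group_normalization_buffers(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    """Group flat buffers like 'normalize_inputs.observation.state.mean' into
    {'observation.state': {'mean': tensor, ...}} (hypothetical key layout)."""
    stats: dict[str, dict[str, torch.Tensor]] = {}
    prefixes = ("normalize_inputs.", "normalize_targets.", "unnormalize_outputs.")
    for key, tensor in state_dict.items():
        if not key.startswith(prefixes) or key.rsplit(".", 1)[-1] not in {"mean", "std", "min", "max"}:
            continue
        remainder = key.split(".", 1)[1]             # drop the normalizer prefix
        feature, stat_name = remainder.rsplit(".", 1)
        stats.setdefault(feature, {})[stat_name] = tensor
    return stats
```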
@@ -158,25 +146,36 @@ def detect_features_and_norm_modes(
    if "normalization_mapping" in config:
        print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
        # Extract normalization modes from config
        for feature_name, mode_str in config["normalization_mapping"].items():
            # Convert string to NormalizationMode enum
            if mode_str == "mean_std":
                mode = NormalizationMode.MEAN_STD
            elif mode_str == "min_max":
                mode = NormalizationMode.MIN_MAX
            else:
                print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
        for feature_type_str, mode_str in config["normalization_mapping"].items():
            # Convert string to FeatureType enum
            try:
                if feature_type_str == "VISUAL":
                    feature_type = FeatureType.VISUAL
                elif feature_type_str == "STATE":
                    feature_type = FeatureType.STATE
                elif feature_type_str == "ACTION":
                    feature_type = FeatureType.ACTION
                else:
                    print(f"Warning: Unknown feature type '{feature_type_str}', skipping")
                    continue
            except (AttributeError, ValueError):
                print(f"Warning: Could not parse feature type '{feature_type_str}', skipping")
                continue

            # Determine feature type from feature name
            if "image" in feature_name or "visual" in feature_name:
                feature_type = FeatureType.VISUAL
            elif "state" in feature_name:
                feature_type = FeatureType.STATE
            elif "action" in feature_name:
                feature_type = FeatureType.ACTION
            else:
                feature_type = FeatureType.STATE
            # Convert string to NormalizationMode enum
            try:
                if mode_str == "MEAN_STD":
                    mode = NormalizationMode.MEAN_STD
                elif mode_str == "MIN_MAX":
                    mode = NormalizationMode.MIN_MAX
                else:
                    print(
                        f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'"
                    )
                    continue
            except (AttributeError, ValueError):
                print(f"Warning: Could not parse normalization mode '{mode_str}', skipping")
                continue

            norm_modes[feature_type] = mode

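The if/elif ladders in the new code above could arguably be collapsed with standard enum name lookup on the same `FeatureType` and `NormalizationMode` enums. A minimal sketch of that alternative follows; it is not what the commit does, only an equivalent-in-spirit illustration (it would also accept any other valid enum member names, unlike the explicit branches).

```python
from lerobot.configs.types import FeatureType, NormalizationMode


def parse_mapping_entry(feature_type_str: str, mode_str: str) -> tuple[FeatureType, NormalizationMode] | None:
    """Return (feature_type, mode), or None when either name is unknown."""
    try:
        # Enum name lookup mirrors the explicit VISUAL/STATE/ACTION and MEAN_STD/MIN_MAX branches.
        return FeatureType[feature_type_str], NormalizationMode[mode_str]
    except KeyError:
        print(f"Warning: could not parse '{feature_type_str}: {mode_str}', skipping")
        return None
```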
@@ -303,7 +302,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[

        # Get shape from feature dict
        shape = feature_dict.get("shape", feature_dict.get("dim"))
        shape = (shape,) if isinstance(shape, int) else tuple(shape)
        shape = (shape,) if isinstance(shape, int) else tuple(shape) if shape is not None else ()

        converted_features[key] = PolicyFeature(feature_type, shape)

@@ -311,7 +310,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[


def load_model_from_hub(
    repo_id: str, revision: str = None
    repo_id: str, revision: str | None = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
    """
    Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
@@ -368,6 +367,12 @@ def main():
    )
    parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
    parser.add_argument("--private", action="store_true", help="Make the hub repository private")
    parser.add_argument(
        "--policy-type",
        type=str,
        required=True,
        help="Policy type (act, diffusion, pi0, pi0fast, smolvla, tdmpc, vqbet, sac, classifier)",
    )

    args = parser.parse_args()

@@ -416,91 +421,51 @@ def main():

    output_dir.mkdir(parents=True, exist_ok=True)

    # Clean up config - remove normalization_mapping field
    # Clean up config - remove fields that shouldn't be passed to config constructor
    cleaned_config = dict(config)
    if "normalization_mapping" in cleaned_config:
        print("Removing 'normalization_mapping' field from config")
        del cleaned_config["normalization_mapping"]
    policy_type = deepcopy(cleaned_config["type"])

    del cleaned_config["type"]
    # Remove fields that are not part of the config class constructors
    fields_to_remove = ["normalization_mapping", "type"]
    for field in fields_to_remove:
        if field in cleaned_config:
            print(f"Removing '{field}' field from config")
            del cleaned_config[field]

    # Instantiate the policy model with cleaned config and load the cleaned state dict
    print(f"Instantiating {policy_type} policy model...")
    policy_class_path = POLICY_CLASSES[policy_type]
    module_path, class_name = policy_class_path.rsplit(".", 1)
    # Use the policy type from command line argument
    policy_type = args.policy_type

    module = importlib.import_module(module_path)
    policy_class = getattr(module, class_name)
    # Convert input_features and output_features to PolicyFeature objects if they exist
    if "input_features" in cleaned_config:
        cleaned_config["input_features"] = convert_features_to_policy_features(
            cleaned_config["input_features"]
        )
    if "output_features" in cleaned_config:
        cleaned_config["output_features"] = convert_features_to_policy_features(
            cleaned_config["output_features"]
        )

    # Create config class instance
    config_module_path = module_path.replace("modeling", "configuration")
    config_module = importlib.import_module(config_module_path)
    # Handle special cases for config class names
    config_class_names = {
        "act": "ACTConfig",
        "diffusion": "DiffusionConfig",
        "pi0": "PI0Config",
        "pi0fast": "PI0FASTConfig",
        "smolvla": "SmolVLAConfig",
        "tdmpc": "TDMPCConfig",
        "vqbet": "VQBeTConfig",
        "sac": "SACConfig",
        "classifier": "ClassifierConfig",
    }
    config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
    config_class = getattr(config_module, config_class_name)
    # Add normalization mapping to config
    cleaned_config["normalization_mapping"] = norm_map

    # Convert input_features and output_features to PolicyFeature objects - these are mandatory
    if "input_features" not in cleaned_config:
        raise ValueError("Missing mandatory 'input_features' in config")
    if "output_features" not in cleaned_config:
        raise ValueError("Missing mandatory 'output_features' in config")
    # Create policy configuration using the factory
    print(f"Creating {policy_type} policy configuration...")
    policy_config = make_policy_config(policy_type, **cleaned_config)

    cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
    cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])

    # Create config instance from cleaned config dict
    policy_config = config_class(**cleaned_config)

    # Create policy instance - some policies expect dataset_stats
    # Create policy instance using the factory
    print(f"Instantiating {policy_type} policy...")
    policy_class = get_policy_class(policy_type)
    policy = policy_class(policy_config)

    # Load the cleaned state dict
    policy.load_state_dict(new_state_dict, strict=True)
    print("Successfully loaded cleaned state dict into policy model")

    # Now create preprocessor and postprocessor with cleaned_config available
    print("Creating preprocessor and postprocessor...")
    # The pattern from existing processor factories:
    # - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
    # - Postprocessor has one UnnormalizerProcessorStep for output_features only

    # Get features from cleaned_config (now they're PolicyFeature objects)
    input_features = cleaned_config.get("input_features", {})
    output_features = cleaned_config.get("output_features", {})

    # Create preprocessor with two normalizers (following the pattern from processor factories)
    preprocessor_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        NormalizerProcessorStep(
            features={**input_features, **output_features},
            norm_map=norm_map,
            stats=stats,
        ),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=policy_config.device),
    ]
    preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")

    # Create postprocessor with unnormalizer for outputs only
    postprocessor_steps = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
    ]
    postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")
    # Create preprocessor and postprocessor using the factory
    print("Creating preprocessor and postprocessor using make_pre_post_processors...")
    preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)

    # Determine hub repo ID if pushing to hub
    hub_repo_id = None
    if args.push_to_hub:
        if args.hub_repo_id:
            hub_repo_id = args.hub_repo_id
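To make the preprocessor/postprocessor split concrete, here is a rough inference-time sketch. It reuses the `policy`, `preprocessor`, and `postprocessor` objects created in the hunk above, and it assumes the pipelines are callable on observation/action dicts; that calling convention and the observation key names are illustrative assumptions, not shown in this diff.

```python
import torch

# Hypothetical single observation from a robot or environment (keys are placeholders).
observation = {
    "observation.state": torch.zeros(14),
    "observation.images.top": torch.zeros(3, 480, 640),
}

# Preprocessor: rename -> normalize -> add batch dim -> move to the policy device (per the steps above).
batch = preprocessor(observation)

# Policy inference, then postprocessor: move back to CPU and unnormalize the action.
with torch.no_grad():
    action = policy.select_action(batch)
action = postprocessor(action)
```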
@@ -510,25 +475,24 @@ def main():
            hub_repo_id = f"{args.pretrained_path}_migrated"
        else:
            raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
    else:
        hub_repo_id = None

    # Save preprocessor and postprocessor to root directory
    print(f"Saving preprocessor to {output_dir}...")
    preprocessor.save_pretrained(output_dir)
    if args.push_to_hub:
    if args.push_to_hub and hub_repo_id:
        preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)

    print(f"Saving postprocessor to {output_dir}...")
    postprocessor.save_pretrained(output_dir)
    if args.push_to_hub:
    if args.push_to_hub and hub_repo_id:
        postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)

    # Save model using the policy's save_pretrained method
    print(f"Saving model to {output_dir}...")
    policy.save_pretrained(
        output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
    )
    if args.push_to_hub and hub_repo_id:
        policy.save_pretrained(output_dir, push_to_hub=True, repo_id=hub_repo_id, private=args.private)
    else:
        policy.save_pretrained(output_dir)

    # Generate and save model card
    print("Generating model card...")
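For the Usage example at the top of the script, the fallback naming in the hunk above would resolve as follows (a small sanity-check illustration, not output from the script):

```python
pretrained_path = "lerobot/act_aloha_sim_transfer_cube_human"
hub_repo_id = f"{pretrained_path}_migrated"
assert hub_repo_id == "lerobot/act_aloha_sim_transfer_cube_human_migrated"
```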
@@ -549,7 +513,7 @@ def main():
    card.save(str(output_dir / "README.md"))
    print(f"Model card saved to {output_dir / 'README.md'}")
    # Push model card to hub if requested
    if args.push_to_hub:
    if args.push_to_hub and hub_repo_id:
        from huggingface_hub import HfApi

        api = HfApi()
@@ -564,7 +528,7 @@ def main():

    print("\nMigration complete!")
    print(f"Migrated model saved to: {output_dir}")
    if args.push_to_hub:
    if args.push_to_hub and hub_repo_id:
        print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")