mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
refactor(processor): migrate policy normalization to use factory functions
- Updated the migration script to utilize `make_pre_post_processors` and `make_policy_config` from `lerobot.policies.factory`, enhancing consistency with the current codebase. - Improved normalization statistics extraction and processor pipeline creation, ensuring compatibility with the new `PolicyProcessorPipeline` architecture. - Cleaned up configuration handling by removing unnecessary fields and adding normalization mapping directly to the config. - Enhanced type safety and readability by refining feature type and normalization mode handling.
This commit is contained in:
@@ -37,13 +37,19 @@ Usage:
|
|||||||
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
|
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
|
||||||
--policy-type act \
|
--policy-type act \
|
||||||
--push-to-hub
|
--push-to-hub
|
||||||
|
|
||||||
|
Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config`
|
||||||
|
factory functions from `lerobot.policies.factory` to create processors and configurations,
|
||||||
|
ensuring consistency with the current codebase.
|
||||||
|
|
||||||
|
The script extracts normalization statistics from the old model's state_dict, creates clean
|
||||||
|
processor pipelines using the factory functions, and saves a migrated model that is compatible
|
||||||
|
with the new PolicyProcessorPipeline architecture.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import importlib
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -52,25 +58,7 @@ from huggingface_hub import hf_hub_download
|
|||||||
from safetensors.torch import load_file as load_safetensors
|
from safetensors.torch import load_file as load_safetensors
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors
|
||||||
from .batch_processor import AddBatchDimensionProcessorStep
|
|
||||||
from .device_processor import DeviceProcessorStep
|
|
||||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
|
||||||
from .pipeline import PolicyProcessorPipeline
|
|
||||||
from .rename_processor import RenameObservationsProcessorStep
|
|
||||||
|
|
||||||
# Policy type to class mapping
|
|
||||||
POLICY_CLASSES = {
|
|
||||||
"act": "lerobot.policies.act.modeling_act.ACTPolicy",
|
|
||||||
"diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
|
|
||||||
"pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
|
|
||||||
"pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
|
|
||||||
"smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
|
|
||||||
"tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
|
|
||||||
"vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
|
|
||||||
"sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
|
|
||||||
"classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||||
@@ -158,25 +146,36 @@ def detect_features_and_norm_modes(
|
|||||||
if "normalization_mapping" in config:
|
if "normalization_mapping" in config:
|
||||||
print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
|
print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
|
||||||
# Extract normalization modes from config
|
# Extract normalization modes from config
|
||||||
for feature_name, mode_str in config["normalization_mapping"].items():
|
for feature_type_str, mode_str in config["normalization_mapping"].items():
|
||||||
# Convert string to NormalizationMode enum
|
# Convert string to FeatureType enum
|
||||||
if mode_str == "mean_std":
|
try:
|
||||||
mode = NormalizationMode.MEAN_STD
|
if feature_type_str == "VISUAL":
|
||||||
elif mode_str == "min_max":
|
feature_type = FeatureType.VISUAL
|
||||||
mode = NormalizationMode.MIN_MAX
|
elif feature_type_str == "STATE":
|
||||||
else:
|
feature_type = FeatureType.STATE
|
||||||
print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
|
elif feature_type_str == "ACTION":
|
||||||
|
feature_type = FeatureType.ACTION
|
||||||
|
else:
|
||||||
|
print(f"Warning: Unknown feature type '{feature_type_str}', skipping")
|
||||||
|
continue
|
||||||
|
except (AttributeError, ValueError):
|
||||||
|
print(f"Warning: Could not parse feature type '{feature_type_str}', skipping")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Determine feature type from feature name
|
# Convert string to NormalizationMode enum
|
||||||
if "image" in feature_name or "visual" in feature_name:
|
try:
|
||||||
feature_type = FeatureType.VISUAL
|
if mode_str == "MEAN_STD":
|
||||||
elif "state" in feature_name:
|
mode = NormalizationMode.MEAN_STD
|
||||||
feature_type = FeatureType.STATE
|
elif mode_str == "MIN_MAX":
|
||||||
elif "action" in feature_name:
|
mode = NormalizationMode.MIN_MAX
|
||||||
feature_type = FeatureType.ACTION
|
else:
|
||||||
else:
|
print(
|
||||||
feature_type = FeatureType.STATE
|
f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except (AttributeError, ValueError):
|
||||||
|
print(f"Warning: Could not parse normalization mode '{mode_str}', skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
norm_modes[feature_type] = mode
|
norm_modes[feature_type] = mode
|
||||||
|
|
||||||
@@ -303,7 +302,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
|||||||
|
|
||||||
# Get shape from feature dict
|
# Get shape from feature dict
|
||||||
shape = feature_dict.get("shape", feature_dict.get("dim"))
|
shape = feature_dict.get("shape", feature_dict.get("dim"))
|
||||||
shape = (shape,) if isinstance(shape, int) else tuple(shape)
|
shape = (shape,) if isinstance(shape, int) else tuple(shape) if shape is not None else ()
|
||||||
|
|
||||||
converted_features[key] = PolicyFeature(feature_type, shape)
|
converted_features[key] = PolicyFeature(feature_type, shape)
|
||||||
|
|
||||||
@@ -311,7 +310,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
|||||||
|
|
||||||
|
|
||||||
def load_model_from_hub(
|
def load_model_from_hub(
|
||||||
repo_id: str, revision: str = None
|
repo_id: str, revision: str | None = None
|
||||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
|
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
|
||||||
@@ -368,6 +367,12 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
|
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
|
||||||
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
|
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
|
||||||
|
parser.add_argument(
|
||||||
|
"--policy-type",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Policy type (act, diffusion, pi0, pi0fast, smolvla, tdmpc, vqbet, sac, classifier)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -416,91 +421,51 @@ def main():
|
|||||||
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Clean up config - remove normalization_mapping field
|
# Clean up config - remove fields that shouldn't be passed to config constructor
|
||||||
cleaned_config = dict(config)
|
cleaned_config = dict(config)
|
||||||
if "normalization_mapping" in cleaned_config:
|
|
||||||
print("Removing 'normalization_mapping' field from config")
|
|
||||||
del cleaned_config["normalization_mapping"]
|
|
||||||
policy_type = deepcopy(cleaned_config["type"])
|
|
||||||
|
|
||||||
del cleaned_config["type"]
|
# Remove fields that are not part of the config class constructors
|
||||||
|
fields_to_remove = ["normalization_mapping", "type"]
|
||||||
|
for field in fields_to_remove:
|
||||||
|
if field in cleaned_config:
|
||||||
|
print(f"Removing '{field}' field from config")
|
||||||
|
del cleaned_config[field]
|
||||||
|
|
||||||
# Instantiate the policy model with cleaned config and load the cleaned state dict
|
# Use the policy type from command line argument
|
||||||
print(f"Instantiating {policy_type} policy model...")
|
policy_type = args.policy_type
|
||||||
policy_class_path = POLICY_CLASSES[policy_type]
|
|
||||||
module_path, class_name = policy_class_path.rsplit(".", 1)
|
|
||||||
|
|
||||||
module = importlib.import_module(module_path)
|
# Convert input_features and output_features to PolicyFeature objects if they exist
|
||||||
policy_class = getattr(module, class_name)
|
if "input_features" in cleaned_config:
|
||||||
|
cleaned_config["input_features"] = convert_features_to_policy_features(
|
||||||
|
cleaned_config["input_features"]
|
||||||
|
)
|
||||||
|
if "output_features" in cleaned_config:
|
||||||
|
cleaned_config["output_features"] = convert_features_to_policy_features(
|
||||||
|
cleaned_config["output_features"]
|
||||||
|
)
|
||||||
|
|
||||||
# Create config class instance
|
# Add normalization mapping to config
|
||||||
config_module_path = module_path.replace("modeling", "configuration")
|
cleaned_config["normalization_mapping"] = norm_map
|
||||||
config_module = importlib.import_module(config_module_path)
|
|
||||||
# Handle special cases for config class names
|
|
||||||
config_class_names = {
|
|
||||||
"act": "ACTConfig",
|
|
||||||
"diffusion": "DiffusionConfig",
|
|
||||||
"pi0": "PI0Config",
|
|
||||||
"pi0fast": "PI0FASTConfig",
|
|
||||||
"smolvla": "SmolVLAConfig",
|
|
||||||
"tdmpc": "TDMPCConfig",
|
|
||||||
"vqbet": "VQBeTConfig",
|
|
||||||
"sac": "SACConfig",
|
|
||||||
"classifier": "ClassifierConfig",
|
|
||||||
}
|
|
||||||
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
|
|
||||||
config_class = getattr(config_module, config_class_name)
|
|
||||||
|
|
||||||
# Convert input_features and output_features to PolicyFeature objects - these are mandatory
|
# Create policy configuration using the factory
|
||||||
if "input_features" not in cleaned_config:
|
print(f"Creating {policy_type} policy configuration...")
|
||||||
raise ValueError("Missing mandatory 'input_features' in config")
|
policy_config = make_policy_config(policy_type, **cleaned_config)
|
||||||
if "output_features" not in cleaned_config:
|
|
||||||
raise ValueError("Missing mandatory 'output_features' in config")
|
|
||||||
|
|
||||||
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
|
# Create policy instance using the factory
|
||||||
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
|
print(f"Instantiating {policy_type} policy...")
|
||||||
|
policy_class = get_policy_class(policy_type)
|
||||||
# Create config instance from cleaned config dict
|
|
||||||
policy_config = config_class(**cleaned_config)
|
|
||||||
|
|
||||||
# Create policy instance - some policies expect dataset_stats
|
|
||||||
policy = policy_class(policy_config)
|
policy = policy_class(policy_config)
|
||||||
|
|
||||||
# Load the cleaned state dict
|
# Load the cleaned state dict
|
||||||
policy.load_state_dict(new_state_dict, strict=True)
|
policy.load_state_dict(new_state_dict, strict=True)
|
||||||
print("Successfully loaded cleaned state dict into policy model")
|
print("Successfully loaded cleaned state dict into policy model")
|
||||||
|
|
||||||
# Now create preprocessor and postprocessor with cleaned_config available
|
# Create preprocessor and postprocessor using the factory
|
||||||
print("Creating preprocessor and postprocessor...")
|
print("Creating preprocessor and postprocessor using make_pre_post_processors...")
|
||||||
# The pattern from existing processor factories:
|
preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
|
||||||
# - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
|
|
||||||
# - Postprocessor has one UnnormalizerProcessorStep for output_features only
|
|
||||||
|
|
||||||
# Get features from cleaned_config (now they're PolicyFeature objects)
|
|
||||||
input_features = cleaned_config.get("input_features", {})
|
|
||||||
output_features = cleaned_config.get("output_features", {})
|
|
||||||
|
|
||||||
# Create preprocessor with two normalizers (following the pattern from processor factories)
|
|
||||||
preprocessor_steps = [
|
|
||||||
RenameObservationsProcessorStep(rename_map={}),
|
|
||||||
NormalizerProcessorStep(
|
|
||||||
features={**input_features, **output_features},
|
|
||||||
norm_map=norm_map,
|
|
||||||
stats=stats,
|
|
||||||
),
|
|
||||||
AddBatchDimensionProcessorStep(),
|
|
||||||
DeviceProcessorStep(device=policy_config.device),
|
|
||||||
]
|
|
||||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")
|
|
||||||
|
|
||||||
# Create postprocessor with unnormalizer for outputs only
|
|
||||||
postprocessor_steps = [
|
|
||||||
DeviceProcessorStep(device="cpu"),
|
|
||||||
UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
|
|
||||||
]
|
|
||||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")
|
|
||||||
|
|
||||||
# Determine hub repo ID if pushing to hub
|
# Determine hub repo ID if pushing to hub
|
||||||
|
hub_repo_id = None
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
if args.hub_repo_id:
|
if args.hub_repo_id:
|
||||||
hub_repo_id = args.hub_repo_id
|
hub_repo_id = args.hub_repo_id
|
||||||
@@ -510,25 +475,24 @@ def main():
|
|||||||
hub_repo_id = f"{args.pretrained_path}_migrated"
|
hub_repo_id = f"{args.pretrained_path}_migrated"
|
||||||
else:
|
else:
|
||||||
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
|
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
|
||||||
else:
|
|
||||||
hub_repo_id = None
|
|
||||||
|
|
||||||
# Save preprocessor and postprocessor to root directory
|
# Save preprocessor and postprocessor to root directory
|
||||||
print(f"Saving preprocessor to {output_dir}...")
|
print(f"Saving preprocessor to {output_dir}...")
|
||||||
preprocessor.save_pretrained(output_dir)
|
preprocessor.save_pretrained(output_dir)
|
||||||
if args.push_to_hub:
|
if args.push_to_hub and hub_repo_id:
|
||||||
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||||
|
|
||||||
print(f"Saving postprocessor to {output_dir}...")
|
print(f"Saving postprocessor to {output_dir}...")
|
||||||
postprocessor.save_pretrained(output_dir)
|
postprocessor.save_pretrained(output_dir)
|
||||||
if args.push_to_hub:
|
if args.push_to_hub and hub_repo_id:
|
||||||
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||||
|
|
||||||
# Save model using the policy's save_pretrained method
|
# Save model using the policy's save_pretrained method
|
||||||
print(f"Saving model to {output_dir}...")
|
print(f"Saving model to {output_dir}...")
|
||||||
policy.save_pretrained(
|
if args.push_to_hub and hub_repo_id:
|
||||||
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
|
policy.save_pretrained(output_dir, push_to_hub=True, repo_id=hub_repo_id, private=args.private)
|
||||||
)
|
else:
|
||||||
|
policy.save_pretrained(output_dir)
|
||||||
|
|
||||||
# Generate and save model card
|
# Generate and save model card
|
||||||
print("Generating model card...")
|
print("Generating model card...")
|
||||||
@@ -549,7 +513,7 @@ def main():
|
|||||||
card.save(str(output_dir / "README.md"))
|
card.save(str(output_dir / "README.md"))
|
||||||
print(f"Model card saved to {output_dir / 'README.md'}")
|
print(f"Model card saved to {output_dir / 'README.md'}")
|
||||||
# Push model card to hub if requested
|
# Push model card to hub if requested
|
||||||
if args.push_to_hub:
|
if args.push_to_hub and hub_repo_id:
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
@@ -564,7 +528,7 @@ def main():
|
|||||||
|
|
||||||
print("\nMigration complete!")
|
print("\nMigration complete!")
|
||||||
print(f"Migrated model saved to: {output_dir}")
|
print(f"Migrated model saved to: {output_dir}")
|
||||||
if args.push_to_hub:
|
if args.push_to_hub and hub_repo_id:
|
||||||
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
|
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user