refactor(migrate_policy_normalization): Enhance preprocessor and postprocessor structure

- Introduced RenameProcessor in the preprocessor to handle renaming features.
- Combined input and output features in a single NormalizerProcessor for improved efficiency.
- Updated RobotProcessor initialization to clarify step naming for preprocessor and postprocessor.
- Added DeviceProcessor to both preprocessor and postprocessor for better device management.
This commit is contained in:
Adil Zouitine
2025-08-07 11:04:15 +02:00
parent 862bc7ef85
commit 0524551f52
@@ -47,8 +47,10 @@ from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.batch_processor import ToBatchProcessor
from lerobot.processor.device_processor import DeviceProcessor
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from lerobot.processor.pipeline import RobotProcessor
from lerobot.processor.rename_processor import RenameProcessor
# Policy type to class mapping
POLICY_CLASSES = {
@@ -410,18 +412,23 @@ def main():
# Create preprocessor with two normalizers (following the pattern from processor factories)
preprocessor_steps = [
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**input_features, **output_features},
norm_map=norm_map,
stats=stats,
),
ToBatchProcessor(),
DeviceProcessor(device=policy_config.device),
]
preprocessor = RobotProcessor(preprocessor_steps, name="preprocessor")
preprocessor = RobotProcessor(steps=preprocessor_steps, name="robot_preprocessor")
# Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
ToBatchProcessor(),
]
postprocessor = RobotProcessor(postprocessor_steps, name="postprocessor")
postprocessor = RobotProcessor(steps=postprocessor_steps, name="robot_postprocessor")
# Determine hub repo ID if pushing to hub
if args.push_to_hub: