mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
refactor(migrate_policy_normalization): Enhance preprocessor and postprocessor structure
- Introduced RenameProcessor in the preprocessor to handle renaming features. - Combined input and output features in a single NormalizerProcessor for improved efficiency. - Updated RobotProcessor initialization to clarify step naming for preprocessor and postprocessor. - Added DeviceProcessor to both preprocessor and postprocessor for better device management.
This commit is contained in:
@@ -47,8 +47,10 @@ from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.batch_processor import ToBatchProcessor
|
||||
from lerobot.processor.device_processor import DeviceProcessor
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.processor.rename_processor import RenameProcessor
|
||||
|
||||
# Policy type to class mapping
|
||||
POLICY_CLASSES = {
|
||||
@@ -410,18 +412,23 @@ def main():
|
||||
|
||||
# Create preprocessor with two normalizers (following the pattern from processor factories)
|
||||
preprocessor_steps = [
|
||||
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
|
||||
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
features={**input_features, **output_features},
|
||||
norm_map=norm_map,
|
||||
stats=stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=policy_config.device),
|
||||
]
|
||||
preprocessor = RobotProcessor(preprocessor_steps, name="preprocessor")
|
||||
preprocessor = RobotProcessor(steps=preprocessor_steps, name="robot_preprocessor")
|
||||
|
||||
# Create postprocessor with unnormalizer for outputs only
|
||||
postprocessor_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||
ToBatchProcessor(),
|
||||
]
|
||||
postprocessor = RobotProcessor(postprocessor_steps, name="postprocessor")
|
||||
postprocessor = RobotProcessor(steps=postprocessor_steps, name="robot_postprocessor")
|
||||
|
||||
# Determine hub repo ID if pushing to hub
|
||||
if args.push_to_hub:
|
||||
|
||||
Reference in New Issue
Block a user