mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
refactor(migrate_policy_normalization): Enhance preprocessor and postprocessor structure
- Introduced RenameProcessor in the preprocessor to handle renaming features. - Combined input and output features in a single NormalizerProcessor for improved efficiency. - Updated RobotProcessor initialization to clarify step naming for preprocessor and postprocessor. - Added DeviceProcessor to both preprocessor and postprocessor for better device management.
This commit is contained in:
@@ -47,8 +47,10 @@ from safetensors.torch import load_file as load_safetensors
|
|||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from lerobot.processor.batch_processor import ToBatchProcessor
|
from lerobot.processor.batch_processor import ToBatchProcessor
|
||||||
|
from lerobot.processor.device_processor import DeviceProcessor
|
||||||
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||||
from lerobot.processor.pipeline import RobotProcessor
|
from lerobot.processor.pipeline import RobotProcessor
|
||||||
|
from lerobot.processor.rename_processor import RenameProcessor
|
||||||
|
|
||||||
# Policy type to class mapping
|
# Policy type to class mapping
|
||||||
POLICY_CLASSES = {
|
POLICY_CLASSES = {
|
||||||
@@ -410,18 +412,23 @@ def main():
|
|||||||
|
|
||||||
# Create preprocessor with two normalizers (following the pattern from processor factories)
|
# Create preprocessor with two normalizers (following the pattern from processor factories)
|
||||||
preprocessor_steps = [
|
preprocessor_steps = [
|
||||||
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
|
RenameProcessor(rename_map={}),
|
||||||
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
NormalizerProcessor(
|
||||||
|
features={**input_features, **output_features},
|
||||||
|
norm_map=norm_map,
|
||||||
|
stats=stats,
|
||||||
|
),
|
||||||
ToBatchProcessor(),
|
ToBatchProcessor(),
|
||||||
|
DeviceProcessor(device=policy_config.device),
|
||||||
]
|
]
|
||||||
preprocessor = RobotProcessor(preprocessor_steps, name="preprocessor")
|
preprocessor = RobotProcessor(steps=preprocessor_steps, name="robot_preprocessor")
|
||||||
|
|
||||||
# Create postprocessor with unnormalizer for outputs only
|
# Create postprocessor with unnormalizer for outputs only
|
||||||
postprocessor_steps = [
|
postprocessor_steps = [
|
||||||
|
DeviceProcessor(device="cpu"),
|
||||||
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||||
ToBatchProcessor(),
|
|
||||||
]
|
]
|
||||||
postprocessor = RobotProcessor(postprocessor_steps, name="postprocessor")
|
postprocessor = RobotProcessor(steps=postprocessor_steps, name="robot_postprocessor")
|
||||||
|
|
||||||
# Determine hub repo ID if pushing to hub
|
# Determine hub repo ID if pushing to hub
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
|
|||||||
Reference in New Issue
Block a user