refactor(factory): Update processor configuration and type hints

- Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety.
- Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility.
- Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations.
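- Streamlined processor initialization: each policy-specific processor factory now receives its config cast to the concrete config type plus an explicit dataset_stats argument (read via kwargs.get("dataset_stats")), instead of the whole **kwargs pass-through.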
This commit is contained in:
Adil Zouitine
2025-08-05 16:54:02 +02:00
committed by Steven Palma
parent 87890cbf38
commit 7fc7ec75bb
+34 -11
View File
@@ -17,8 +17,9 @@
from __future__ import annotations
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, cast
import torch
from torch import nn
from typing_extensions import Unpack
@@ -41,7 +42,7 @@ from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor.pipeline import RobotProcessor
def get_policy_class(name: str) -> PreTrainedPolicy:
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@@ -113,6 +114,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
postprocessor_config_filename: str | None
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
def make_processor(
@@ -155,49 +157,68 @@ def make_processor(
# Create a new processor based on policy type
if policy_cfg.type == "tdmpc":
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
processors = make_tdmpc_processor(policy_cfg, **kwargs)
processors = make_tdmpc_processor(
config=cast(TDMPCConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "diffusion":
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
processors = make_diffusion_processor(policy_cfg, **kwargs)
processors = make_diffusion_processor(
cast(DiffusionConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "act":
from lerobot.policies.act.processor_act import make_act_processor
processors = make_act_processor(policy_cfg, **kwargs)
processors = make_act_processor(
config=cast(ACTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "vqbet":
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
processors = make_vqbet_processor(policy_cfg, **kwargs)
processors = make_vqbet_processor(
config=cast(VQBeTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "pi0":
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
processors = make_pi0_processor(policy_cfg, **kwargs)
processors = make_pi0_processor(
config=cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "pi0fast":
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
processors = make_pi0fast_processor(policy_cfg, **kwargs)
processors = make_pi0fast_processor(
cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "sac":
from lerobot.policies.sac.processor_sac import make_sac_processor
processors = make_sac_processor(policy_cfg, **kwargs)
processors = make_sac_processor(
cast(SACConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "reward_classifier":
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(policy_cfg, **kwargs)
processors = make_classifier_processor(
cast(RewardClassifierConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "smolvla":
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
processors = make_smolvla_processor(policy_cfg, **kwargs)
processors = make_smolvla_processor(
cast(SmolVLAConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
@@ -258,6 +279,8 @@ def make_policy(
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
if env_cfg is None:
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}