diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index e5e2700c7..f2ab93b57 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -17,8 +17,9 @@ from __future__ import annotations import logging -from typing import Any, TypedDict +from typing import Any, TypedDict, cast +import torch from torch import nn from typing_extensions import Unpack @@ -41,7 +42,7 @@ from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.processor.pipeline import RobotProcessor -def get_policy_class(name: str) -> PreTrainedPolicy: +def get_policy_class(name: str) -> type[PreTrainedPolicy]: """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" if name == "tdmpc": from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy @@ -113,6 +114,7 @@ class ProcessorConfigKwargs(TypedDict, total=False): postprocessor_config_filename: str | None preprocessor_overrides: dict[str, Any] | None postprocessor_overrides: dict[str, Any] | None + dataset_stats: dict[str, dict[str, torch.Tensor]] | None def make_processor( @@ -155,49 +157,68 @@ def make_processor( # Create a new processor based on policy type if policy_cfg.type == "tdmpc": + from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor - processors = make_tdmpc_processor(policy_cfg, **kwargs) + processors = make_tdmpc_processor( + config=cast(TDMPCConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + ) elif policy_cfg.type == "diffusion": from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor - processors = make_diffusion_processor(policy_cfg, **kwargs) + processors = make_diffusion_processor( + cast(DiffusionConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + ) elif policy_cfg.type == "act": from lerobot.policies.act.processor_act import make_act_processor - processors = make_act_processor(policy_cfg, **kwargs) + processors = make_act_processor( + config=cast(ACTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + ) elif policy_cfg.type == "vqbet": from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor - processors = make_vqbet_processor(policy_cfg, **kwargs) + processors = make_vqbet_processor( + config=cast(VQBeTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + ) elif policy_cfg.type == "pi0": from lerobot.policies.pi0.processor_pi0 import make_pi0_processor - processors = make_pi0_processor(policy_cfg, **kwargs) + processors = make_pi0_processor( + config=cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + ) elif policy_cfg.type == "pi0fast": from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor - processors = make_pi0fast_processor(policy_cfg, **kwargs) + processors = make_pi0fast_processor( + cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + ) elif policy_cfg.type == "sac": from lerobot.policies.sac.processor_sac import make_sac_processor - processors = make_sac_processor(policy_cfg, **kwargs) + processors = make_sac_processor( + cast(SACConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + ) elif policy_cfg.type == "reward_classifier": from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor - processors = make_classifier_processor(policy_cfg, **kwargs) + processors = make_classifier_processor( + cast(RewardClassifierConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + ) elif policy_cfg.type == "smolvla": from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor - processors = make_smolvla_processor(policy_cfg, **kwargs) + processors = make_smolvla_processor( + cast(SmolVLAConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats") + ) else: raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") @@ -258,6 +279,8 @@ def make_policy( "rather than a dataset. Normalization modules inside the policy will have infinite values " "by default without stats from a dataset." ) + if env_cfg is None: + raise ValueError("env_cfg cannot be None when ds_meta is not provided") features = env_to_policy_features(env_cfg) cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}