diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index a19ee4737..8740d529e 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -29,6 +29,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig +from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig @@ -67,6 +68,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy return PI0OpenPIPolicy + elif name == "pi05_openpi": + from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy + + return PI05OpenPIPolicy elif name == "sac": from lerobot.policies.sac.modeling_sac import SACPolicy @@ -98,6 +103,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return PI0FASTConfig(**kwargs) elif policy_type == "pi0_openpi": return PI0OpenPIConfig(**kwargs) + elif policy_type == "pi05_openpi": + return PI05OpenPIConfig(**kwargs) elif policy_type == "sac": return SACConfig(**kwargs) elif policy_type == "smolvla":