add pi05 to factory

This commit is contained in:
Pepijn
2025-09-11 11:01:31 +02:00
parent 8d1434c069
commit 384ec52ec7
+7
View File
@@ -29,6 +29,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -67,6 +68,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
return PI0OpenPIPolicy
elif name == "pi05_openpi":
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
return PI05OpenPIPolicy
elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -98,6 +103,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0FASTConfig(**kwargs)
elif policy_type == "pi0_openpi":
return PI0OpenPIConfig(**kwargs)
elif policy_type == "pi05_openpi":
return PI05OpenPIConfig(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":