mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-28 21:57:27 +00:00
add pi05 to factory
This commit is contained in:
@@ -29,6 +29,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
@@ -67,6 +68,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
|
||||
|
||||
return PI0OpenPIPolicy
|
||||
elif name == "pi05_openpi":
|
||||
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
|
||||
|
||||
return PI05OpenPIPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -98,6 +103,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "pi0_openpi":
|
||||
return PI0OpenPIConfig(**kwargs)
|
||||
elif policy_type == "pi05_openpi":
|
||||
return PI05OpenPIConfig(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
|
||||
Reference in New Issue
Block a user