diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 197d61944..bba1c894f 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -87,11 +87,11 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: return PI0FASTPolicy elif name == "pi0": - from lerobot.policies.pi0.modeling_pi0openpi import PI0Policy + from lerobot.policies.pi0.modeling_pi0 import PI0Policy return PI0Policy elif name == "pi05": - from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy + from lerobot.policies.pi05.modeling_pi05 import PI05Policy return PI05Policy elif name == "sac":