diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index c0b12c121..49f1e0f95 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -14,8 +14,8 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig -from .pi0.configuration_pi0openpi import PI0OpenPIConfig as PI0OpenPIConfig -from .pi05.configuration_pi05openpi import PI05OpenPIConfig as PI05OpenPIConfig +from .pi0.configuration_pi0 import PI0Config as PI0Config +from .pi05.configuration_pi05 import PI05Config as PI05Config from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig @@ -24,8 +24,8 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig __all__ = [ "ACTConfig", "DiffusionConfig", - "PI0OpenPIConfig", - "PI05OpenPIConfig", + "PI0Config", + "PI05Config", "SmolVLAConfig", "TDMPCConfig", "VQBeTConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index da66ac400..197d61944 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -31,9 +31,9 @@ from lerobot.envs.configs import EnvConfig from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig -from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig +from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig -from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig +from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig @@ -87,13 +87,13 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: return PI0FASTPolicy elif name == "pi0": - from lerobot.policies.pi0.modeling_pi0openpi import PI0OpenPIPolicy + from lerobot.policies.pi0.modeling_pi0openpi import PI0Policy - return PI0OpenPIPolicy + return PI0Policy elif name == "pi05": - from lerobot.policies.pi05.modeling_pi05openpi import PI05OpenPIPolicy + from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy - return PI05OpenPIPolicy + return PI05Policy elif name == "sac": from lerobot.policies.sac.modeling_sac import SACPolicy @@ -140,9 +140,9 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: elif policy_type == "pi0fast": return PI0FASTConfig(**kwargs) elif policy_type == "pi0": - return PI0OpenPIConfig(**kwargs) + return PI0Config(**kwargs) elif policy_type == "pi05": - return PI05OpenPIConfig(**kwargs) + return PI05Config(**kwargs) elif policy_type == "sac": return SACConfig(**kwargs) elif policy_type == "smolvla": diff --git a/src/lerobot/policies/pi0/__init__.py b/src/lerobot/policies/pi0/__init__.py index 12d766633..15f89bf55 100644 --- a/src/lerobot/policies/pi0/__init__.py +++ b/src/lerobot/policies/pi0/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .configuration_pi0openpi import PI0OpenPIConfig -from .modeling_pi0openpi import PI0OpenPIPolicy +from .configuration_pi0 import PI0Config +from .modeling_pi0 import PI0Policy -__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy"] +__all__ = ["PI0Config", "PI0Policy"] diff --git a/src/lerobot/policies/pi0/configuration_pi0openpi.py b/src/lerobot/policies/pi0/configuration_pi0.py similarity index 100% rename from src/lerobot/policies/pi0/configuration_pi0openpi.py rename to src/lerobot/policies/pi0/configuration_pi0.py diff --git a/src/lerobot/policies/pi0/modeling_pi0openpi.py b/src/lerobot/policies/pi0/modeling_pi0.py similarity index 100% rename from src/lerobot/policies/pi0/modeling_pi0openpi.py rename to src/lerobot/policies/pi0/modeling_pi0.py diff --git a/src/lerobot/policies/pi05/__init__.py b/src/lerobot/policies/pi05/__init__.py index 2b438db85..161d8fbc9 100644 --- a/src/lerobot/policies/pi05/__init__.py +++ b/src/lerobot/policies/pi05/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .configuration_pi05openpi import PI05OpenPIConfig -from .modeling_pi05openpi import PI05OpenPIPolicy +from .configuration_pi05 import PI05Config +from .modeling_pi05 import PI05Policy -__all__ = ["PI05OpenPIConfig", "PI05OpenPIPolicy"] +__all__ = ["PI05Config", "PI05Policy"] diff --git a/src/lerobot/policies/pi05/configuration_pi05openpi.py b/src/lerobot/policies/pi05/configuration_pi05.py similarity index 100% rename from src/lerobot/policies/pi05/configuration_pi05openpi.py rename to src/lerobot/policies/pi05/configuration_pi05.py diff --git a/src/lerobot/policies/pi05/modeling_pi05openpi.py b/src/lerobot/policies/pi05/modeling_pi05.py similarity index 100% rename from src/lerobot/policies/pi05/modeling_pi05openpi.py rename to src/lerobot/policies/pi05/modeling_pi05.py diff --git a/tests/policies/pi0_pi05/test_pi0_pi05_hub.py b/tests/policies/pi0_pi05/test_pi0_pi05_hub.py index 92e918422..63125e871 100644 --- a/tests/policies/pi0_pi05/test_pi0_pi05_hub.py +++ b/tests/policies/pi0_pi05/test_pi0_pi05_hub.py @@ -19,7 +19,7 @@ pytestmark = pytest.mark.skipif( ) from lerobot.policies.pi0 import PI0Policy # noqa: E402 -from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy # noqa: E402 +from lerobot.policies.pi05.modeling_pi05 import PI05Policy # noqa: E402 def create_dummy_stats(config):