rename pi0/pi05 files

This commit is contained in:
Pepijn
2025-09-23 09:48:45 +02:00
parent d691d1e4fe
commit 969e8eeae1
9 changed files with 19 additions and 19 deletions
+4 -4
View File
@@ -14,8 +14,8 @@
from .act.configuration_act import ACTConfig as ACTConfig from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0openpi import PI0OpenPIConfig as PI0OpenPIConfig from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi05.configuration_pi05openpi import PI05OpenPIConfig as PI05OpenPIConfig from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -24,8 +24,8 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
__all__ = [ __all__ = [
"ACTConfig", "ACTConfig",
"DiffusionConfig", "DiffusionConfig",
"PI0OpenPIConfig", "PI0Config",
"PI05OpenPIConfig", "PI05Config",
"SmolVLAConfig", "SmolVLAConfig",
"TDMPCConfig", "TDMPCConfig",
"VQBeTConfig", "VQBeTConfig",
+8 -8
View File
@@ -31,9 +31,9 @@ from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -87,13 +87,13 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
return PI0FASTPolicy return PI0FASTPolicy
elif name == "pi0": elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0openpi import PI0OpenPIPolicy from lerobot.policies.pi0.modeling_pi0openpi import PI0Policy
return PI0OpenPIPolicy return PI0Policy
elif name == "pi05": elif name == "pi05":
from lerobot.policies.pi05.modeling_pi05openpi import PI05OpenPIPolicy from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy
return PI05OpenPIPolicy return PI05Policy
elif name == "sac": elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -140,9 +140,9 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
elif policy_type == "pi0fast": elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs) return PI0FASTConfig(**kwargs)
elif policy_type == "pi0": elif policy_type == "pi0":
return PI0OpenPIConfig(**kwargs) return PI0Config(**kwargs)
elif policy_type == "pi05": elif policy_type == "pi05":
return PI05OpenPIConfig(**kwargs) return PI05Config(**kwargs)
elif policy_type == "sac": elif policy_type == "sac":
return SACConfig(**kwargs) return SACConfig(**kwargs)
elif policy_type == "smolvla": elif policy_type == "smolvla":
+3 -3
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .configuration_pi0openpi import PI0OpenPIConfig from .configuration_pi0 import PI0Config
from .modeling_pi0openpi import PI0OpenPIPolicy from .modeling_pi0 import PI0Policy
__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy"] __all__ = ["PI0Config", "PI0Policy"]
+3 -3
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .configuration_pi05openpi import PI05OpenPIConfig from .configuration_pi05 import PI05Config
from .modeling_pi05openpi import PI05OpenPIPolicy from .modeling_pi05 import PI05Policy
__all__ = ["PI05OpenPIConfig", "PI05OpenPIPolicy"] __all__ = ["PI05Config", "PI05Policy"]
+1 -1
View File
@@ -19,7 +19,7 @@ pytestmark = pytest.mark.skipif(
) )
from lerobot.policies.pi0 import PI0Policy # noqa: E402 from lerobot.policies.pi0 import PI0Policy # noqa: E402
from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy # noqa: E402 from lerobot.policies.pi05.modeling_pi05 import PI05Policy # noqa: E402
def create_dummy_stats(config): def create_dummy_stats(config):