mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
rename pi0/pi05 files
This commit is contained in:
@@ -14,8 +14,8 @@
|
|||||||
|
|
||||||
from .act.configuration_act import ACTConfig as ACTConfig
|
from .act.configuration_act import ACTConfig as ACTConfig
|
||||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||||
from .pi0.configuration_pi0openpi import PI0OpenPIConfig as PI0OpenPIConfig
|
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||||
from .pi05.configuration_pi05openpi import PI05OpenPIConfig as PI05OpenPIConfig
|
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||||
@@ -24,8 +24,8 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"ACTConfig",
|
"ACTConfig",
|
||||||
"DiffusionConfig",
|
"DiffusionConfig",
|
||||||
"PI0OpenPIConfig",
|
"PI0Config",
|
||||||
"PI05OpenPIConfig",
|
"PI05Config",
|
||||||
"SmolVLAConfig",
|
"SmolVLAConfig",
|
||||||
"TDMPCConfig",
|
"TDMPCConfig",
|
||||||
"VQBeTConfig",
|
"VQBeTConfig",
|
||||||
|
|||||||
@@ -31,9 +31,9 @@ from lerobot.envs.configs import EnvConfig
|
|||||||
from lerobot.envs.utils import env_to_policy_features
|
from lerobot.envs.utils import env_to_policy_features
|
||||||
from lerobot.policies.act.configuration_act import ACTConfig
|
from lerobot.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig
|
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||||
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
|
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||||
@@ -87,13 +87,13 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
|
|
||||||
return PI0FASTPolicy
|
return PI0FASTPolicy
|
||||||
elif name == "pi0":
|
elif name == "pi0":
|
||||||
from lerobot.policies.pi0.modeling_pi0openpi import PI0OpenPIPolicy
|
from lerobot.policies.pi0.modeling_pi0openpi import PI0Policy
|
||||||
|
|
||||||
return PI0OpenPIPolicy
|
return PI0Policy
|
||||||
elif name == "pi05":
|
elif name == "pi05":
|
||||||
from lerobot.policies.pi05.modeling_pi05openpi import PI05OpenPIPolicy
|
from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy
|
||||||
|
|
||||||
return PI05OpenPIPolicy
|
return PI05Policy
|
||||||
elif name == "sac":
|
elif name == "sac":
|
||||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||||
|
|
||||||
@@ -140,9 +140,9 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
elif policy_type == "pi0fast":
|
elif policy_type == "pi0fast":
|
||||||
return PI0FASTConfig(**kwargs)
|
return PI0FASTConfig(**kwargs)
|
||||||
elif policy_type == "pi0":
|
elif policy_type == "pi0":
|
||||||
return PI0OpenPIConfig(**kwargs)
|
return PI0Config(**kwargs)
|
||||||
elif policy_type == "pi05":
|
elif policy_type == "pi05":
|
||||||
return PI05OpenPIConfig(**kwargs)
|
return PI05Config(**kwargs)
|
||||||
elif policy_type == "sac":
|
elif policy_type == "sac":
|
||||||
return SACConfig(**kwargs)
|
return SACConfig(**kwargs)
|
||||||
elif policy_type == "smolvla":
|
elif policy_type == "smolvla":
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .configuration_pi0openpi import PI0OpenPIConfig
|
from .configuration_pi0 import PI0Config
|
||||||
from .modeling_pi0openpi import PI0OpenPIPolicy
|
from .modeling_pi0 import PI0Policy
|
||||||
|
|
||||||
__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy"]
|
__all__ = ["PI0Config", "PI0Policy"]
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .configuration_pi05openpi import PI05OpenPIConfig
|
from .configuration_pi05 import PI05Config
|
||||||
from .modeling_pi05openpi import PI05OpenPIPolicy
|
from .modeling_pi05 import PI05Policy
|
||||||
|
|
||||||
__all__ = ["PI05OpenPIConfig", "PI05OpenPIPolicy"]
|
__all__ = ["PI05Config", "PI05Policy"]
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ pytestmark = pytest.mark.skipif(
|
|||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||||
from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy # noqa: E402
|
from lerobot.policies.pi05.modeling_pi05 import PI05Policy # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_stats(config):
|
def create_dummy_stats(config):
|
||||||
|
|||||||
Reference in New Issue
Block a user