This commit is contained in:
Pepijn
2025-09-24 11:33:24 +02:00
parent d71f0d4a05
commit 6bd23541d6
8 changed files with 11 additions and 12 deletions
-1
View File
@@ -144,7 +144,6 @@ python src/lerobot/scripts/eval.py \
--eval.batch_size=1 \
--eval.n_episodes=5 \
--policy.path=pepijn223/pi0_libero_fp32 \
--env.multitask_eval=true \
--output_dir=./eval_logs/ \
--policy.compile_model=false \
--policy.gradient_checkpointing=false \
@@ -24,7 +24,7 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0OpenPIConfig(PreTrainedConfig):
class PI0Config(PreTrainedConfig):
# Model architecture
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
+1 -1
View File
@@ -900,7 +900,7 @@ class PI0Policy(PreTrainedPolicy):
) -> T:
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
"⚠️ DISCLAIMER: The PI0 model is a direct PyTorch port of the OpenPI implementation. \n"
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
+2 -2
View File
@@ -34,7 +34,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi
from lerobot.configs.policies import PreTrainedConfig
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi05.configuration_pi05openpi import PI05Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -874,7 +874,7 @@ class PI05Policy(PreTrainedPolicy):
) -> T:
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
"⚠️ DISCLAIMER: The PI05 model is a direct PyTorch port of the OpenPI implementation. \n"
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
"""Test script to verify PI0.5 (pi05) support in PI0OpenPI policy, only meant to be run locally!"""
"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!"""
import os
+4 -4
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
"""Test script to verify PI0OpenPI policy integration with LeRobot, only meant to be run locally!"""
"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!"""
import os
@@ -14,14 +14,14 @@ pytestmark = pytest.mark.skipif(
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi0 import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_policy_instantiation():
# Create config
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -61,7 +61,7 @@ def test_policy_instantiation():
}
# Instantiate policy
policy = PI0OpenPIPolicy(config, dataset_stats)
policy = PI0Policy(config, dataset_stats)
# Test forward pass with dummy data
batch_size = 1
@@ -1,4 +1,4 @@
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import os
+1 -1
View File
@@ -2,7 +2,7 @@
# TODO(pepijn): Remove these tests before merging
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
"""Test script to load PI0 model from HuggingFace hub and run inference."""
import os