refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy

- Updated imports and references throughout the codebase to reflect the new naming convention.
- Introduced a new processor file for PI0 to handle pre-processing and post-processing steps.
- Adjusted tests to utilize the renamed classes, ensuring consistency and functionality.
- Enhanced clarity and maintainability by removing outdated naming conventions.
This commit is contained in:
AdilZouitine
2025-09-23 10:03:39 +02:00
parent 9d58086912
commit 28fa7eae72
10 changed files with 42 additions and 48 deletions
+7 -9
View File
@@ -14,10 +14,10 @@ pytestmark = pytest.mark.skipif(
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi0_openpi import ( # noqa: E402
PI0OpenPIConfig,
PI0OpenPIPolicy,
make_pi0_openpi_pre_post_processors, # noqa: E402
from lerobot.policies.pi0 import ( # noqa: E402
PI0Config,
PI0Policy,
make_pi0_pre_post_processors, # noqa: E402
)
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402
@@ -30,7 +30,7 @@ def test_policy_instantiation():
# Create config
set_seed(42)
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -70,10 +70,8 @@ def test_policy_instantiation():
}
# Instantiate policy
policy = PI0OpenPIPolicy(config)
preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
config=config, dataset_stats=dataset_stats
)
policy = PI0Policy(config)
preprocessor, postprocessor = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
# Test forward pass with dummy data
batch_size = 1
device = config.device
@@ -23,8 +23,8 @@ from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
from lerobot.policies.pi0_openpi.processor_pi0_openpi import make_pi0_openpi_pre_post_processors # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
from lerobot.policies.pi0.processor_pi0_openpi import make_pi0_pre_post_processors # noqa: E402
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
@@ -73,24 +73,20 @@ class PI0BaseOriginalConfig:
def instantiate_lerobot_pi0(
from_pretrained: bool = False,
) -> tuple[
PI0OpenPIPolicy,
PI0Policy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained:
# Load the policy first
policy = PI0OpenPIPolicy.from_pretrained(
pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
)
policy = PI0Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True)
else:
config = PI0OpenPIConfig(
max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
)
policy = PI0OpenPIPolicy(config)
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI0Policy(config)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
preprocessor, postprocessor = make_pi0_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)