mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy
- Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions.
This commit is contained in:
@@ -14,10 +14,10 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi0_openpi import ( # noqa: E402
|
||||
PI0OpenPIConfig,
|
||||
PI0OpenPIPolicy,
|
||||
make_pi0_openpi_pre_post_processors, # noqa: E402
|
||||
from lerobot.policies.pi0 import ( # noqa: E402
|
||||
PI0Config,
|
||||
PI0Policy,
|
||||
make_pi0_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
@@ -30,7 +30,7 @@ def test_policy_instantiation():
|
||||
# Create config
|
||||
|
||||
set_seed(42)
|
||||
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -70,10 +70,8 @@ def test_policy_instantiation():
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config)
|
||||
preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
|
||||
config=config, dataset_stats=dataset_stats
|
||||
)
|
||||
policy = PI0Policy(config)
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
# Test forward pass with dummy data
|
||||
batch_size = 1
|
||||
device = config.device
|
||||
|
||||
@@ -23,8 +23,8 @@ from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
|
||||
from lerobot.policies.pi0_openpi.processor_pi0_openpi import make_pi0_openpi_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0_openpi import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
@@ -73,24 +73,20 @@ class PI0BaseOriginalConfig:
|
||||
def instantiate_lerobot_pi0(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI0OpenPIPolicy,
|
||||
PI0Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI0OpenPIPolicy.from_pretrained(
|
||||
pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
|
||||
)
|
||||
policy = PI0Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True)
|
||||
else:
|
||||
config = PI0OpenPIConfig(
|
||||
max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
|
||||
)
|
||||
policy = PI0OpenPIPolicy(config)
|
||||
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI0Policy(config)
|
||||
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
Reference in New Issue
Block a user