mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy
- Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions.
This commit is contained in:
@@ -14,10 +14,10 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi0_openpi import ( # noqa: E402
|
||||
PI0OpenPIConfig,
|
||||
PI0OpenPIPolicy,
|
||||
make_pi0_openpi_pre_post_processors, # noqa: E402
|
||||
from lerobot.policies.pi0 import ( # noqa: E402
|
||||
PI0Config,
|
||||
PI0Policy,
|
||||
make_pi0_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
@@ -30,7 +30,7 @@ def test_policy_instantiation():
|
||||
# Create config
|
||||
|
||||
set_seed(42)
|
||||
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -70,10 +70,8 @@ def test_policy_instantiation():
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config)
|
||||
preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
|
||||
config=config, dataset_stats=dataset_stats
|
||||
)
|
||||
policy = PI0Policy(config)
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
# Test forward pass with dummy data
|
||||
batch_size = 1
|
||||
device = config.device
|
||||
|
||||
Reference in New Issue
Block a user