mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy
- Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions.
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0openpi import PI0OpenPIConfig as PI0OpenPIConfig
|
||||
from .pi0.configuration_pi0openpi import PI0Config
|
||||
from .pi05.configuration_pi05openpi import PI05OpenPIConfig as PI05OpenPIConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
@@ -24,7 +24,7 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"PI0OpenPIConfig",
|
||||
"PI0Config",
|
||||
"PI05OpenPIConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
|
||||
@@ -31,7 +31,7 @@ from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pi0.configuration_pi0openpi import PI0Config
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
@@ -140,7 +140,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0OpenPIConfig(**kwargs)
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05OpenPIConfig(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
@@ -150,7 +150,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "pi0_openpi":
|
||||
return PI0OpenPIConfig(**kwargs)
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05_openpi":
|
||||
return PI05OpenPIConfig(**kwargs)
|
||||
else:
|
||||
@@ -272,16 +272,16 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0OpenPIConfig):
|
||||
from lerobot.policies.pi0_openpi.processor_pi0_openpi import make_pi0_openpi_pre_post_processors
|
||||
elif isinstance(policy_cfg, PI0Config):
|
||||
from lerobot.policies.pi0.processor_pi0_openpi import make_pi0_pre_post_processors
|
||||
|
||||
processors = make_pi0_openpi_pre_post_processors(
|
||||
processors = make_pi0_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05OpenPIConfig):
|
||||
from lerobot.policies.pi05_openpi.processor_pi05openpi import make_pi05_openpi_pre_post_processors
|
||||
from lerobot.policies.pi05.processor_pi05openpi import make_pi05_openpi_pre_post_processors
|
||||
|
||||
processors = make_pi05_openpi_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi0openpi import PI0OpenPIConfig
|
||||
from .modeling_pi0openpi import PI0OpenPIPolicy
|
||||
from .processor_pi0_openpi import make_pi0_openpi_pre_post_processors
|
||||
from .configuration_pi0openpi import PI0Config
|
||||
from .modeling_pi0openpi import PI0Policy
|
||||
from .processor_pi0_openpi import make_pi0_pre_post_processors
|
||||
|
||||
__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy", "make_pi0_openpi_pre_post_processors"]
|
||||
__all__ = ["PI0Config", "PI0Policy", "make_pi0_pre_post_processors"]
|
||||
|
||||
@@ -24,7 +24,7 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
class PI0OpenPIConfig(PreTrainedConfig):
|
||||
class PI0Config(PreTrainedConfig):
|
||||
# Model architecture
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
|
||||
@@ -31,7 +31,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pi0.configuration_pi0openpi import PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
|
||||
|
||||
@@ -490,7 +490,7 @@ class PaliGemmaWithExpertModel(
|
||||
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI0 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI0OpenPIConfig):
|
||||
def __init__(self, config: PI0Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -839,15 +839,15 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
|
||||
class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
class PI0Policy(PreTrainedPolicy):
|
||||
"""PI0 OpenPI Policy for LeRobot."""
|
||||
|
||||
config_class = PI0OpenPIConfig
|
||||
config_class = PI0Config
|
||||
name = "pi0"
|
||||
|
||||
def __init__( # see lerobot pi0 `__init__`
|
||||
self,
|
||||
config: PI0OpenPIConfig,
|
||||
config: PI0Config,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
+6
-6
@@ -18,7 +18,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pi0.configuration_pi0openpi import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
@@ -35,8 +35,8 @@ from lerobot.processor import (
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_openpi_new_line_processor")
|
||||
class Pi0OpenPINewLineProcessor(ComplementaryDataProcessorStep):
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
|
||||
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Ensures that the task description string ends with a newline character.
|
||||
|
||||
@@ -92,8 +92,8 @@ class Pi0OpenPINewLineProcessor(ComplementaryDataProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def make_pi0_openpi_pre_post_processors(
|
||||
config: PI0OpenPIConfig,
|
||||
def make_pi0_pre_post_processors(
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -128,7 +128,7 @@ def make_pi0_openpi_pre_post_processors(
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
Pi0OpenPINewLineProcessor(), # Add newlines before tokenization for PaliGemma
|
||||
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
@@ -31,7 +31,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
|
||||
|
||||
|
||||
+2
-2
@@ -6,8 +6,8 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pi05_openpi.modeling_pi05openpi import pad_vector
|
||||
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pi05.modeling_pi05openpi import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
@@ -14,10 +14,10 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi0_openpi import ( # noqa: E402
|
||||
PI0OpenPIConfig,
|
||||
PI0OpenPIPolicy,
|
||||
make_pi0_openpi_pre_post_processors, # noqa: E402
|
||||
from lerobot.policies.pi0 import ( # noqa: E402
|
||||
PI0Config,
|
||||
PI0Policy,
|
||||
make_pi0_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
@@ -30,7 +30,7 @@ def test_policy_instantiation():
|
||||
# Create config
|
||||
|
||||
set_seed(42)
|
||||
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -70,10 +70,8 @@ def test_policy_instantiation():
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config)
|
||||
preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
|
||||
config=config, dataset_stats=dataset_stats
|
||||
)
|
||||
policy = PI0Policy(config)
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
# Test forward pass with dummy data
|
||||
batch_size = 1
|
||||
device = config.device
|
||||
|
||||
@@ -23,8 +23,8 @@ from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
|
||||
from lerobot.policies.pi0_openpi.processor_pi0_openpi import make_pi0_openpi_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0_openpi import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
@@ -73,24 +73,20 @@ class PI0BaseOriginalConfig:
|
||||
def instantiate_lerobot_pi0(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI0OpenPIPolicy,
|
||||
PI0Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI0OpenPIPolicy.from_pretrained(
|
||||
pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
|
||||
)
|
||||
policy = PI0Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True)
|
||||
else:
|
||||
config = PI0OpenPIConfig(
|
||||
max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
|
||||
)
|
||||
policy = PI0OpenPIPolicy(config)
|
||||
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI0Policy(config)
|
||||
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
Reference in New Issue
Block a user