load from pretrained_path

This commit is contained in:
Pepijn
2025-09-13 14:27:07 +02:00
parent b9df1a4ac5
commit af0676f99e
4 changed files with 10 additions and 4 deletions
@@ -66,6 +66,9 @@ class PI05OpenPIConfig(PreTrainedConfig):
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Pretrained model loading
pretrained_path: str | None = None # Path or repo_id to load pretrained weights from
# Optimizer settings: see openpi `AdamW` and `CosineDecaySchedule`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
@@ -875,8 +875,8 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
# Create default config
config = cls.config_class()
# Use provided config if available, otherwise create default config
config = kwargs.get("config", cls.config_class())
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
@@ -63,6 +63,9 @@ class PI0OpenPIConfig(PreTrainedConfig):
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Pretrained model loading
pretrained_path: str | None = None # Path or repo_id to load pretrained weights from
# Optimizer settings: see openpi `AdamW` and `CosineDecaySchedule`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
@@ -894,8 +894,8 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
# Create default config
config = cls.config_class()
# Use provided config if available, otherwise create default config
config = kwargs.get("config", cls.config_class())
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs