mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
load from pretrained_path
This commit is contained in:
@@ -66,6 +66,9 @@ class PI05OpenPIConfig(PreTrainedConfig):
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Pretrained model loading
|
||||
pretrained_path: str | None = None # Path or repo_id to load pretrained weights from
|
||||
|
||||
# Optimizer settings: see openpi `AdamW` and
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
|
||||
@@ -875,8 +875,8 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
if pretrained_name_or_path is None:
|
||||
raise ValueError("pretrained_name_or_path is required")
|
||||
|
||||
# Create default config
|
||||
config = cls.config_class()
|
||||
# Use provided config if available, otherwise create default config
|
||||
config = kwargs.get("config", cls.config_class())
|
||||
|
||||
# Initialize model without loading weights
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
|
||||
@@ -63,6 +63,9 @@ class PI0OpenPIConfig(PreTrainedConfig):
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Pretrained model loading
|
||||
pretrained_path: str | None = None # Path or repo_id to load pretrained weights from
|
||||
|
||||
# Optimizer settings: see openpi `AdamW` and
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
|
||||
@@ -894,8 +894,8 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
if pretrained_name_or_path is None:
|
||||
raise ValueError("pretrained_name_or_path is required")
|
||||
|
||||
# Create default config
|
||||
config = cls.config_class()
|
||||
# Use provided config if available, otherwise create default config
|
||||
config = kwargs.get("config", cls.config_class())
|
||||
|
||||
# Initialize model without loading weights
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
|
||||
Reference in New Issue
Block a user