mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
load from pretrained_path
This commit is contained in:
@@ -66,6 +66,9 @@ class PI05OpenPIConfig(PreTrainedConfig):
|
|||||||
compile_mode: str = "max-autotune" # Torch compile mode
|
compile_mode: str = "max-autotune" # Torch compile mode
|
||||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||||
|
|
||||||
|
# Pretrained model loading
|
||||||
|
pretrained_path: str | None = None # Path or repo_id to load pretrained weights from
|
||||||
|
|
||||||
# Optimizer settings: see openpi `AdamW` and
|
# Optimizer settings: see openpi `AdamW` and
|
||||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||||
|
|||||||
@@ -875,8 +875,8 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
|||||||
if pretrained_name_or_path is None:
|
if pretrained_name_or_path is None:
|
||||||
raise ValueError("pretrained_name_or_path is required")
|
raise ValueError("pretrained_name_or_path is required")
|
||||||
|
|
||||||
# Create default config
|
# Use provided config if available, otherwise create default config
|
||||||
config = cls.config_class()
|
config = kwargs.get("config", cls.config_class())
|
||||||
|
|
||||||
# Initialize model without loading weights
|
# Initialize model without loading weights
|
||||||
# Check if dataset_stats were provided in kwargs
|
# Check if dataset_stats were provided in kwargs
|
||||||
|
|||||||
@@ -63,6 +63,9 @@ class PI0OpenPIConfig(PreTrainedConfig):
|
|||||||
compile_mode: str = "max-autotune" # Torch compile mode
|
compile_mode: str = "max-autotune" # Torch compile mode
|
||||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||||
|
|
||||||
|
# Pretrained model loading
|
||||||
|
pretrained_path: str | None = None # Path or repo_id to load pretrained weights from
|
||||||
|
|
||||||
# Optimizer settings: see openpi `AdamW` and
|
# Optimizer settings: see openpi `AdamW` and
|
||||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||||
|
|||||||
@@ -894,8 +894,8 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
|||||||
if pretrained_name_or_path is None:
|
if pretrained_name_or_path is None:
|
||||||
raise ValueError("pretrained_name_or_path is required")
|
raise ValueError("pretrained_name_or_path is required")
|
||||||
|
|
||||||
# Create default config
|
# Use provided config if available, otherwise create default config
|
||||||
config = cls.config_class()
|
config = kwargs.get("config", cls.config_class())
|
||||||
|
|
||||||
# Initialize model without loading weights
|
# Initialize model without loading weights
|
||||||
# Check if dataset_stats were provided in kwargs
|
# Check if dataset_stats were provided in kwargs
|
||||||
|
|||||||
Reference in New Issue
Block a user