Compare commits

...

3 Commits

Author SHA1 Message Date
Jade Choghari 62d23b0986 add for rest of policies 2026-02-27 16:32:33 +01:00
Jade Choghari a6a2f3662a Merge branch 'main' into speedup-pi05-launch 2026-02-27 18:12:21 +03:00
Jeremiah Coholich 49444652c6 speedup pi-05 modeling loading by 72s 2026-02-20 15:41:44 -05:00
3 changed files with 24 additions and 3 deletions
+7
View File
@@ -995,6 +995,13 @@ class PI0Policy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs)
# Now manually load and remap the state dict
@@ -967,6 +967,13 @@ class PI05Policy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs)
# Now manually load and remap the state dict
@@ -895,6 +895,13 @@ class PI0FastPolicy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs)
# Now manually load and remap the state dict