Compare commits

...

3 Commits

Author SHA1 Message Date
Jade Choghari 62d23b0986 add for rest of policies 2026-02-27 16:32:33 +01:00
Jade Choghari a6a2f3662a Merge branch 'main' into speedup-pi05-launch 2026-02-27 18:12:21 +03:00
Jeremiah Coholich 49444652c6 speedup pi-05 modeling loading by 72s 2026-02-20 15:41:44 -05:00
3 changed files with 24 additions and 3 deletions
+8 -1
View File
@@ -995,7 +995,14 @@ class PI0Policy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
-model = cls(config, **kwargs)
+if _transformers_available:
+from transformers.modeling_utils import no_init_weights
+with no_init_weights():
+model = cls(config, **kwargs)
+model.model.paligemma_with_expert.paligemma.tie_weights()
+else:
+model = cls(config, **kwargs)
# Now manually load and remap the state dict
try:
+8 -1
View File
@@ -967,7 +967,14 @@ class PI05Policy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
-model = cls(config, **kwargs)
+if _transformers_available:
+from transformers.modeling_utils import no_init_weights
+with no_init_weights():
+model = cls(config, **kwargs)
+model.model.paligemma_with_expert.paligemma.tie_weights()
+else:
+model = cls(config, **kwargs)
# Now manually load and remap the state dict
try:
+8 -1
View File
@@ -895,7 +895,14 @@ class PI0FastPolicy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
-model = cls(config, **kwargs)
+if _transformers_available:
+from transformers.modeling_utils import no_init_weights
+with no_init_weights():
+model = cls(config, **kwargs)
+model.model.paligemma_with_expert.paligemma.tie_weights()
+else:
+model = cls(config, **kwargs)
# Now manually load and remap the state dict
try: