fix(pi): avoid peak RAM in PiGemma construction by freeing replaced submodules (#3454)

Co-Authored-By: Daiki Kamata <daiki.kamata@access-company.com>
Co-Authored-By: Jack Vial <jackvial@users.noreply.github.com>
Co-Authored-By: Ajay Anubolu <AjAnubolu@users.noreply.github.com>
Co-Authored-By: Finn F. <F-Fer@users.noreply.github.com>
This commit is contained in:
Steven Palma
2026-04-24 17:50:12 +02:00
committed by GitHub
parent 580d818aa9
commit 05a5223885
2 changed files with 7 additions and 0 deletions
@@ -227,6 +227,7 @@ class PI0FastPaliGemma(nn.Module):
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
if use_adarms[0]:
text_config = self.paligemma.config.text_config
del self.paligemma.model.language_model
self.paligemma.model.language_model = PiGemmaModel(text_config)
self.to_bfloat16_for_selected_params(precision)
Second changed file (+6 additions):
@@ -197,6 +197,9 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
# Free parent-allocated layers/norm before replacing to avoid ~2x peak memory.
del self.layers
del self.norm
# if not getattr(config, "use_adarms", False):
# return
cond_dim = getattr(config, "adarms_cond_dim", None)
@@ -328,6 +331,7 @@ class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
del self.model
self.model = PiGemmaModel(config)
@@ -336,6 +340,7 @@ class PaliGemmaModelWithPiGemma(PaliGemmaModel):
def __init__(self, config):
super().__init__(config)
del self.language_model
self.language_model = PiGemmaModel(config.text_config)
@@ -344,6 +349,7 @@ class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGenera
def __init__(self, config):
super().__init__(config)
del self.model
self.model = PaliGemmaModelWithPiGemma(config)
# Make modules available through conditional class for BC