mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix(pi): avoid peak RAM in PiGemma construction by freeing replaced submodules (#3454)
Co-Authored-By: Daiki Kamata <daiki.kamata@access-company.com> Co-Authored-By: Jack Vial <jackvial@users.noreply.github.com> Co-Authored-By: Ajay Anubolu <AjAnubolu@users.noreply.github.com> Co-Authored-By: Finn F. <F-Fer@users.noreply.github.com>
This commit is contained in:
@@ -227,6 +227,7 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||||
if use_adarms[0]:
|
if use_adarms[0]:
|
||||||
text_config = self.paligemma.config.text_config
|
text_config = self.paligemma.config.text_config
|
||||||
|
del self.paligemma.model.language_model
|
||||||
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
|
|||||||
@@ -197,6 +197,9 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
|||||||
|
|
||||||
def __init__(self, config: GemmaConfig, **kwargs):
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
|
# Free parent-allocated layers/norm before replacing to avoid ~2x peak memory.
|
||||||
|
del self.layers
|
||||||
|
del self.norm
|
||||||
# if not getattr(config, "use_adarms", False):
|
# if not getattr(config, "use_adarms", False):
|
||||||
# return
|
# return
|
||||||
cond_dim = getattr(config, "adarms_cond_dim", None)
|
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||||
@@ -328,6 +331,7 @@ class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
|||||||
|
|
||||||
def __init__(self, config: GemmaConfig, **kwargs):
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
|
del self.model
|
||||||
self.model = PiGemmaModel(config)
|
self.model = PiGemmaModel(config)
|
||||||
|
|
||||||
|
|
||||||
@@ -336,6 +340,7 @@ class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
del self.language_model
|
||||||
self.language_model = PiGemmaModel(config.text_config)
|
self.language_model = PiGemmaModel(config.text_config)
|
||||||
|
|
||||||
|
|
||||||
@@ -344,6 +349,7 @@ class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGenera
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
del self.model
|
||||||
self.model = PaliGemmaModelWithPiGemma(config)
|
self.model = PaliGemmaModelWithPiGemma(config)
|
||||||
|
|
||||||
# Make modules available through conditional class for BC
|
# Make modules available through conditional class for BC
|
||||||
|
|||||||
Reference in New Issue
Block a user