From 05a5223885bcd36064fc1a967620329696595a76 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 24 Apr 2026 17:50:12 +0200 Subject: [PATCH] fix(pi): avoid peak RAM in PiGemma construction by freeing replaced submodules (#3454) Co-Authored-By: Daiki Kamata Co-Authored-By: Jack Vial Co-Authored-By: Ajay Anubolu Co-Authored-By: Finn F. --- src/lerobot/policies/pi0_fast/modeling_pi0_fast.py | 1 + src/lerobot/policies/pi_gemma.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index a49828ad1..0bc301609 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -227,6 +227,7 @@ class PI0FastPaliGemma(nn.Module): # forward(..., adarms_cond=...) is supported (same as pi0/pi05). if use_adarms[0]: text_config = self.paligemma.config.text_config + del self.paligemma.model.language_model self.paligemma.model.language_model = PiGemmaModel(text_config) self.to_bfloat16_for_selected_params(precision) diff --git a/src/lerobot/policies/pi_gemma.py b/src/lerobot/policies/pi_gemma.py index 05f031d08..9986f9b79 100644 --- a/src/lerobot/policies/pi_gemma.py +++ b/src/lerobot/policies/pi_gemma.py @@ -197,6 +197,9 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc] def __init__(self, config: GemmaConfig, **kwargs): super().__init__(config, **kwargs) + # Free parent-allocated layers/norm before replacing to avoid ~2x peak memory. + del self.layers + del self.norm # if not getattr(config, "use_adarms", False): # return cond_dim = getattr(config, "adarms_cond_dim", None) @@ -328,6 +331,7 @@ class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc] def __init__(self, config: GemmaConfig, **kwargs): super().__init__(config, **kwargs) + del self.model self.model = PiGemmaModel(config) @@ -336,6 +340,7 @@ class PaliGemmaModelWithPiGemma(PaliGemmaModel): def __init__(self, config): super().__init__(config) + del self.language_model self.language_model = PiGemmaModel(config.text_config) @@ -344,6 +349,7 @@ class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGenera def __init__(self, config): super().__init__(config) + del self.model self.model = PaliGemmaModelWithPiGemma(config) # Make modules available through conditional class for BC