From 3bc3bf0391e0ab9500af2b208ac5c281dfbf1634 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 10 Sep 2025 20:24:39 +0200 Subject: [PATCH] fix autodocstring --- .../models/gemma/modeling_gemma.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py index 6ddb6c337..b596bcad5 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py @@ -45,6 +45,16 @@ from .configuration_gemma import GemmaConfig logger = logging.get_logger(__name__) +# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring +def safe_auto_docstring(func): + """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" + try: + return auto_docstring(func) + except (AttributeError, TypeError): + # If auto_docstring fails due to UnionType, just return the function unchanged + return func + + class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): super().__init__() @@ -391,7 +401,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): return outputs -@auto_docstring +@safe_auto_docstring class GemmaPreTrainedModel(PreTrainedModel): config_class = GemmaConfig base_model_prefix = "model" @@ -422,7 +432,7 @@ class GemmaPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) -@auto_docstring +@safe_auto_docstring class GemmaModel(GemmaPreTrainedModel): def __init__(self, config: GemmaConfig): super().__init__(config) @@ -449,7 +459,7 @@ class GemmaModel(GemmaPreTrainedModel): self.embed_tokens = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -567,7 +577,7 @@ class GemmaModel(GemmaPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@auto_docstring +@safe_auto_docstring class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -601,7 +611,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -685,7 +695,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): ) -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The Gemma Model transformer with a sequence classification head on top (linear layer). @@ -716,7 +726,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -792,7 +802,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): ) -@auto_docstring +@safe_auto_docstring class GemmaForTokenClassification(GemmaPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -817,7 +827,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None,