fix autodocstring

This commit is contained in:
Pepijn
2025-09-10 20:24:39 +02:00
parent 8c5fe10d6c
commit 3bc3bf0391
@@ -45,6 +45,16 @@ from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
    """Apply ``auto_docstring`` but fall back to the undecorated object on failure.

    Works in both decorator forms used in this module:

    * bare: ``@safe_auto_docstring``
    * factory: ``@safe_auto_docstring(custom_intro="...")``

    Args:
        func: The class or function to document. ``None`` when the decorator
            is used in its factory form (called with keyword arguments only).
        **kwargs: Keyword arguments forwarded unchanged to ``auto_docstring``
            (for example ``custom_intro``).

    Returns:
        The object returned by ``auto_docstring``; the original object
        unchanged if ``auto_docstring`` raised ``AttributeError`` or
        ``TypeError``; or, in the factory form, a decorator doing the same.
    """

    def decorate(target):
        try:
            if kwargs:
                # Factory form: build the configured decorator, then apply it.
                return auto_docstring(**kwargs)(target)
            return auto_docstring(target)
        except (AttributeError, TypeError) as exc:
            # Best-effort by design: docstring generation must never block the
            # import of the model. Leave a trace so the skip is debuggable.
            logger.debug(
                "auto_docstring skipped for %s: %s",
                getattr(target, "__qualname__", target),
                exc,
            )
            return target

    if func is None:
        # Called as ``@safe_auto_docstring(...)``: hand back the real decorator.
        return decorate
    return decorate(func)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
@@ -391,7 +401,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
return outputs
@auto_docstring
@safe_auto_docstring
class GemmaPreTrainedModel(PreTrainedModel):
config_class = GemmaConfig
base_model_prefix = "model"
@@ -422,7 +432,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
@auto_docstring
@safe_auto_docstring
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
@@ -449,7 +459,7 @@ class GemmaModel(GemmaPreTrainedModel):
self.embed_tokens = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -567,7 +577,7 @@ class GemmaModel(GemmaPreTrainedModel):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring
@safe_auto_docstring
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -601,7 +611,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
return self.model
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -685,7 +695,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
)
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The Gemma Model transformer with a sequence classification head on top (linear layer).
@@ -716,7 +726,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
self.model.embed_tokens = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -792,7 +802,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
)
@auto_docstring
@safe_auto_docstring
class GemmaForTokenClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -817,7 +827,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
self.model.embed_tokens = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,