mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
fix autodocstring
This commit is contained in:
+19
-9
@@ -45,6 +45,16 @@ from .configuration_gemma import GemmaConfig
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
try:
|
||||
return auto_docstring(func)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return func
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||
super().__init__()
|
||||
@@ -391,7 +401,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
|
||||
return outputs
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@safe_auto_docstring
|
||||
class GemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = GemmaConfig
|
||||
base_model_prefix = "model"
|
||||
@@ -422,7 +432,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@safe_auto_docstring
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
@@ -449,7 +459,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
self.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -567,7 +577,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@safe_auto_docstring
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
@@ -601,7 +611,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -685,7 +695,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
@safe_auto_docstring(
|
||||
custom_intro="""
|
||||
The Gemma Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
@@ -716,7 +726,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -792,7 +802,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@safe_auto_docstring
|
||||
class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@@ -817,7 +827,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
@safe_auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
Reference in New Issue
Block a user