Convert `_tied_weights_keys` on three Florence2 model classes from a list of
parameter names to a dict that maps each listed embedding parameter name to a
`shared.weight` parameter name. In the two wrapper classes the key paths also
gain the `model.` / `language_model.model.` prefix.

NOTE(review): the old lists on Florence2LanguageForConditionalGeneration and
Florence2ForConditionalGeneration also named `lm_head.weight`; the new dicts
have no lm_head entry. Confirm that dropping it is intentional.

diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py
index 2b5316fae..e33efe5c3 100644
--- a/src/lerobot/policies/xvla/modeling_florence2.py
+++ b/src/lerobot/policies/xvla/modeling_florence2.py
@@ -1951,7 +1951,10 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
 
 
 class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = {
+        "encoder.embed_tokens.weight": "shared.weight",
+        "decoder.embed_tokens.weight": "shared.weight",
+    }
 
     def __init__(self, config: Florence2LanguageConfig):
         super().__init__(config)
@@ -2076,7 +2079,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
 
 class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = {
+        "model.encoder.embed_tokens.weight": "model.shared.weight",
+        "model.decoder.embed_tokens.weight": "model.shared.weight",
+    }
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
 
     def __init__(self, config: Florence2LanguageConfig):
@@ -2436,11 +2442,10 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
     FLORENCE2_START_DOCSTRING,
 )
 class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
-    _tied_weights_keys = [
-        "language_model.encoder.embed_tokens.weight",
-        "language_model.decoder.embed_tokens.weight",
-        "language_model.lm_head.weight",
-    ]
+    _tied_weights_keys = {
+        "language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
+        "language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
+    }
 
     def __init__(self, config: Florence2Config):
         super().__init__(config)