mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix(annotate): try AutoModelForImageTextToText first, fall back to AutoModelForVision2Seq
Newer transformers versions renamed/removed AutoModelForVision2Seq in favour of AutoModelForImageTextToText for VL models. Try the new name first and fall back gracefully so the transformers backend works on both transformers 4.45-4.5x and 5.x. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -166,11 +166,22 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient:
|
|||||||
def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
||||||
try:
|
try:
|
||||||
import torch # type: ignore[import-not-found]
|
import torch # type: ignore[import-not-found]
|
||||||
from transformers import AutoModelForVision2Seq, AutoProcessor # type: ignore[import-not-found]
|
import transformers # type: ignore[import-not-found]
|
||||||
|
from transformers import AutoProcessor # type: ignore[import-not-found]
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
raise ImportError("transformers + torch are required for backend='transformers'.") from exc
|
raise ImportError("transformers + torch are required for backend='transformers'.") from exc
|
||||||
|
auto_cls = (
|
||||||
|
getattr(transformers, "AutoModelForImageTextToText", None)
|
||||||
|
or getattr(transformers, "AutoModelForVision2Seq", None)
|
||||||
|
)
|
||||||
|
if auto_cls is None:
|
||||||
|
raise ImportError(
|
||||||
|
"Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this "
|
||||||
|
"transformers version. Install transformers>=4.45 (which has AutoModelForImageTextToText) "
|
||||||
|
"for VL models."
|
||||||
|
)
|
||||||
processor = AutoProcessor.from_pretrained(config.model_id)
|
processor = AutoProcessor.from_pretrained(config.model_id)
|
||||||
model = AutoModelForVision2Seq.from_pretrained(config.model_id, torch_dtype="auto")
|
model = auto_cls.from_pretrained(config.model_id, torch_dtype="auto")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||||
|
|||||||
Reference in New Issue
Block a user