From de104936bf29c7c2374b4ea7a8aef3abd7219b9b Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 27 Apr 2026 22:31:56 +0200 Subject: [PATCH] fix(annotate): try AutoModelForImageTextToText first, fall back to AutoModelForVision2Seq Newer transformers versions renamed/removed AutoModelForVision2Seq in favour of AutoModelForImageTextToText for VL models. Try the new name first and fall back gracefully so the transformers backend works on both transformers 4.45-4.5x and 5.x. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../annotations/steerable_pipeline/vlm_client.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/lerobot/annotations/steerable_pipeline/vlm_client.py b/src/lerobot/annotations/steerable_pipeline/vlm_client.py index 60e86627b..734e33ffd 100644 --- a/src/lerobot/annotations/steerable_pipeline/vlm_client.py +++ b/src/lerobot/annotations/steerable_pipeline/vlm_client.py @@ -166,11 +166,22 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient: def _make_transformers_client(config: VlmConfig) -> VlmClient: try: import torch # type: ignore[import-not-found] - from transformers import AutoModelForVision2Seq, AutoProcessor # type: ignore[import-not-found] + import transformers # type: ignore[import-not-found] + from transformers import AutoProcessor # type: ignore[import-not-found] except ImportError as exc: raise ImportError("transformers + torch are required for backend='transformers'.") from exc + auto_cls = ( + getattr(transformers, "AutoModelForImageTextToText", None) + or getattr(transformers, "AutoModelForVision2Seq", None) + ) + if auto_cls is None: + raise ImportError( + "Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this " + "transformers version. Install transformers>=4.45 (which has AutoModelForImageTextToText) " + "for VL models." + ) processor = AutoProcessor.from_pretrained(config.model_id) - model = AutoModelForVision2Seq.from_pretrained(config.model_id, torch_dtype="auto") + model = auto_cls.from_pretrained(config.model_id, torch_dtype="auto") model.eval() def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]: