refactor: consolidate VLM classes into single QwenVL implementation

Remove Qwen2VL, Qwen3VL, Qwen35VL in favor of one QwenVL class that
uses AutoModelForImageTextToText and works with the whole Qwen VL
family. Moves shared _parse_skills_response to BaseVLM and extracts
_build_messages/_prepare_inputs/_decode helpers to reduce duplication.

Made-with: Cursor
This commit is contained in:
Pepijn
2026-03-30 20:37:09 +02:00
parent 2545f1a8ed
commit 63fad12e8d
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# VLM Interface (Abstract Base Class for Modularity)
import json import json
import logging
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
@@ -27,75 +26,9 @@ from lerobot.utils.constants import (
format_subtask_labels_section, format_subtask_labels_section,
) )
logger = logging.getLogger(__name__)
class BaseVLM(ABC): DEFAULT_MODEL = "Qwen/Qwen3.5-27B"
"""
Abstract base class for Vision-Language Models.
To add a new VLM:
1. Create a subclass of BaseVLM
2. Implement the `__init__`, `segment_skills`, and `segment_skills_batch` methods
3. Register it in the VLM_REGISTRY dictionary
"""
@abstractmethod
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
"""Initialize the VLM with model name, device, and dtype."""
pass
@abstractmethod
def segment_skills(
self,
video_path: Path,
episode_duration: float,
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
"""
Segment a video into atomic skills.
Args:
video_path: Path to the video file
episode_duration: Total duration of the episode in seconds
coarse_goal: Optional high-level task description
subtask_labels: If provided, model must choose only from these labels (closed vocabulary)
Returns:
List of Skill objects representing atomic manipulation skills
"""
pass
@abstractmethod
def segment_skills_batch(
self,
video_paths: list[Path],
episode_durations: list[float],
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[list[Skill]]:
"""
Segment multiple videos into atomic skills in a single batch.
Args:
video_paths: List of paths to video files
episode_durations: List of episode durations in seconds
coarse_goal: Optional high-level task description
Returns:
List of skill lists, one for each video
"""
pass
def _unpack_video_inputs(
video_inputs: list | None,
) -> tuple[list | None, list[dict] | None]:
"""Unpack (tensor, metadata) tuples returned by process_vision_info with return_video_metadata=True."""
if not video_inputs:
return None, None
videos = [v[0] for v in video_inputs]
metadata = [v[1] for v in video_inputs]
return videos, metadata
def create_skill_segmentation_prompt( def create_skill_segmentation_prompt(
@@ -103,9 +36,7 @@ def create_skill_segmentation_prompt(
subtask_labels: list[str] | None = None, subtask_labels: list[str] | None = None,
duration_seconds: float | None = None, duration_seconds: float | None = None,
) -> str: ) -> str:
"""Create the prompt for skill segmentation using the template from constants. """Create the prompt for skill segmentation using the template from constants."""
duration_seconds is required. When subtask_labels is provided, uses closed-vocabulary section.
"""
if duration_seconds is None: if duration_seconds is None:
raise ValueError("duration_seconds is required for skill segmentation prompt") raise ValueError("duration_seconds is required for skill segmentation prompt")
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else "" goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
@@ -119,29 +50,21 @@ def create_skill_segmentation_prompt(
) )
# Qwen2-VL Implementation class BaseVLM(ABC):
"""
Abstract base class for Vision-Language Models used in skill segmentation.
To add a new VLM family:
1. Subclass BaseVLM
2. Implement __init__, segment_skills, and segment_skills_batch
3. Register it in get_vlm()
"""
class Qwen2VL(BaseVLM): @abstractmethod
"""Qwen2-VL model for skill segmentation."""
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16): def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
from qwen_vl_utils import process_vision_info pass
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
self.device = device
self.model_name = model_name
self.process_vision_info = process_vision_info
print(f"Loading Qwen2-VL model: {model_name}...")
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
)
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
print(f" Model loaded successfully on {device}")
@abstractmethod
def segment_skills( def segment_skills(
self, self,
video_path: Path, video_path: Path,
@@ -149,48 +72,10 @@ class Qwen2VL(BaseVLM):
coarse_goal: str | None = None, coarse_goal: str | None = None,
subtask_labels: list[str] | None = None, subtask_labels: list[str] | None = None,
) -> list[Skill]: ) -> list[Skill]:
"""Segment video into skills using Qwen2-VL.""" """Segment a single video into atomic skills."""
prompt = create_skill_segmentation_prompt( pass
coarse_goal, subtask_labels, duration_seconds=episode_duration
)
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
{
"role": "user",
"content": [
{"type": "video", "video": str(video_path), "fps": 1.0},
{
"type": "text",
"text": f"Video duration: {duration_str} (exactly {episode_duration:.1f} seconds). Segment into atomic skills. Last skill must end at {episode_duration:.1f}.",
},
],
},
]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = self.process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
response = self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
)[0].strip()
return self._parse_skills_response(response)
@abstractmethod
def segment_skills_batch( def segment_skills_batch(
self, self,
video_paths: list[Path], video_paths: list[Path],
@@ -198,69 +83,11 @@ class Qwen2VL(BaseVLM):
coarse_goal: str | None = None, coarse_goal: str | None = None,
subtask_labels: list[str] | None = None, subtask_labels: list[str] | None = None,
) -> list[list[Skill]]: ) -> list[list[Skill]]:
"""Segment multiple videos into skills using Qwen2-VL in a batch.""" """Segment multiple videos into atomic skills in a single batch."""
all_messages = [] pass
for video_path, duration in zip(video_paths, episode_durations, strict=True):
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
{
"role": "user",
"content": [
{"type": "video", "video": str(video_path), "fps": 1.0},
{
"type": "text",
"text": f"Video duration: {duration_str} (exactly {duration:.1f} seconds). Segment into atomic skills. Last skill must end at {duration:.1f}.",
},
],
},
]
all_messages.append(messages)
all_texts = []
all_video_inputs = []
for messages in all_messages:
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = self.process_vision_info(messages)
all_texts.append(text)
all_video_inputs.extend(video_inputs or [])
inputs = self.processor(
text=all_texts,
videos=all_video_inputs or None,
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
responses = self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
)
# Parse each response
all_skills = []
for idx, response in enumerate(responses):
try:
skills = self._parse_skills_response(response.strip())
if not skills:
print(f"Warning: No skills parsed from response for video {idx}")
all_skills.append(skills)
except Exception as e:
print(f"Warning: Failed to parse response for video {idx}: {e}")
all_skills.append([])
return all_skills
def _parse_skills_response(self, response: str) -> list[Skill]: def _parse_skills_response(self, response: str) -> list[Skill]:
"""Parse the VLM response into Skill objects.""" """Parse JSON skill list from VLM response text."""
# Extract JSON from response
if "```json" in response: if "```json" in response:
response = response.split("```json")[1].split("```")[0] response = response.split("```json")[1].split("```")[0]
elif "```" in response: elif "```" in response:
@@ -272,7 +99,6 @@ class Qwen2VL(BaseVLM):
if isinstance(skills_data, list): if isinstance(skills_data, list):
return [Skill.from_dict(s) for s in skills_data] return [Skill.from_dict(s) for s in skills_data]
except json.JSONDecodeError: except json.JSONDecodeError:
# Try to find JSON object in response
match = re.search(r"\{.*\}", response, re.DOTALL) match = re.search(r"\{.*\}", response, re.DOTALL)
if match: if match:
try: try:
@@ -280,219 +106,40 @@ class Qwen2VL(BaseVLM):
skills_data = data.get("skills", []) skills_data = data.get("skills", [])
return [Skill.from_dict(s) for s in skills_data] return [Skill.from_dict(s) for s in skills_data]
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
excerpt = response[:200] raise ValueError(f"Could not parse JSON from VLM response: {response[:200]}...") from e
raise ValueError(
f"Could not parse JSON from VLM response (fallback failed): {excerpt}..."
) from e
raise ValueError(f"Could not parse skills from response: {response[:200]}...") raise ValueError(f"Could not parse skills from response: {response[:200]}...")
# Qwen3-VL Implementation (MoE variant) class QwenVL(BaseVLM):
"""Qwen VL model for skill segmentation (default: Qwen3.5 series).
Uses qwen-vl-utils for video processing and the HuggingFace transformers
class Qwen3VL(BaseVLM): Qwen3VLProcessor pipeline. Requires transformers >= 5.4.0 for correct
"""Qwen3-VL MoE model for skill segmentation.""" video position embeddings.
"""
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16): def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration from transformers import AutoModelForImageTextToText, AutoProcessor
self.device = device self.device = device
self.model_name = model_name self.model_name = model_name
self.process_vision_info = process_vision_info self.process_vision_info = process_vision_info
print(f"Loading Qwen3-VL model: {model_name}...") logger.info(f"Loading model: {model_name}...")
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained( self.model = AutoModelForImageTextToText.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
)
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
print(f" Model loaded successfully on {device}")
def segment_skills(
self,
video_path: Path,
episode_duration: float,
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
"""Segment video into skills using Qwen3-VL."""
prompt = create_skill_segmentation_prompt(
coarse_goal, subtask_labels, duration_seconds=episode_duration
)
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
{
"role": "user",
"content": [
{"type": "video", "video": str(video_path), "fps": 1.0},
{
"type": "text",
"text": f"Video duration: {duration_str} (exactly {episode_duration:.1f} seconds). Segment into atomic skills. Last skill must end at {episode_duration:.1f}.",
},
],
},
]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
videos, video_metadata = _unpack_video_inputs(video_inputs)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=videos,
videos_kwargs={
"video_metadata": video_metadata,
"do_sample_frames": False,
},
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
response = self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
)[0].strip()
return self._parse_skills_response(response)
def segment_skills_batch(
self,
video_paths: list[Path],
episode_durations: list[float],
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[list[Skill]]:
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
all_messages = []
for video_path, duration in zip(video_paths, episode_durations, strict=True):
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
{
"role": "user",
"content": [
{"type": "video", "video": str(video_path), "fps": 1.0},
{
"type": "text",
"text": f"Video duration: {duration_str} (exactly {duration:.1f} seconds). Segment into atomic skills. Last skill must end at {duration:.1f}.",
},
],
},
]
all_messages.append(messages)
all_texts = []
all_video_tuples = []
for messages in all_messages:
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
all_texts.append(text)
all_video_tuples.extend(video_inputs or [])
videos, video_metadata = _unpack_video_inputs(all_video_tuples or None)
inputs = self.processor(
text=all_texts,
videos=videos,
videos_kwargs={
"video_metadata": video_metadata,
"do_sample_frames": False,
},
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
responses = self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
)
# Parse each response
all_skills = []
for idx, response in enumerate(responses):
try:
skills = self._parse_skills_response(response.strip())
if not skills:
print(f"Warning: No skills parsed from response for video {idx}")
all_skills.append(skills)
except Exception as e:
print(f"Warning: Failed to parse response for video {idx}: {e}")
all_skills.append([])
return all_skills
def _parse_skills_response(self, response: str) -> list[Skill]:
"""Parse the VLM response into Skill objects."""
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
try:
data = json.loads(response)
skills_data = data.get("skills", data)
if isinstance(skills_data, list):
return [Skill.from_dict(s) for s in skills_data]
except json.JSONDecodeError:
match = re.search(r"\{.*\}", response, re.DOTALL)
if match:
data = json.loads(match.group())
skills_data = data.get("skills", [])
return [Skill.from_dict(s) for s in skills_data]
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
# Qwen3.5-VL Implementation (Qwen3_5ForConditionalGeneration)
class Qwen35VL(BaseVLM):
"""Qwen3.5-VL model for skill segmentation (Qwen3_5ForConditionalGeneration)."""
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration
self.device = device
self.model_name = model_name
self.process_vision_info = process_vision_info
print(f"Loading Qwen3.5-VL model: {model_name}...")
self.model = Qwen3_5ForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
) )
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
self.processor.tokenizer.padding_side = "left" self.processor.tokenizer.padding_side = "left"
print(f" Model loaded successfully on {device}")
def segment_skills( logger.info(f"Model loaded on {device}")
self,
video_path: Path, def _build_messages(self, video_path: Path, episode_duration: float, prompt: str) -> list[dict]:
episode_duration: float,
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
"""Segment video into skills using Qwen3.5-VL."""
prompt = create_skill_segmentation_prompt(
coarse_goal, subtask_labels, duration_seconds=episode_duration
)
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}" duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
messages = [ return [
{"role": "system", "content": [{"type": "text", "text": prompt}]}, {"role": "system", "content": [{"type": "text", "text": prompt}]},
{ {
"role": "user", "role": "user",
@@ -500,18 +147,28 @@ class Qwen35VL(BaseVLM):
{"type": "video", "video": str(video_path), "fps": 1.0}, {"type": "video", "video": str(video_path), "fps": 1.0},
{ {
"type": "text", "type": "text",
"text": f"Video duration: {duration_str} (exactly {episode_duration:.1f} seconds). Segment into atomic skills. Last skill must end at {episode_duration:.1f}.", "text": (
f"Video duration: {duration_str} (exactly {episode_duration:.1f} seconds). "
f"Segment into atomic skills. Last skill must end at {episode_duration:.1f}."
),
}, },
], ],
}, },
] ]
def _prepare_inputs(self, messages: list[dict]) -> dict:
"""Tokenize a single message and return processor inputs on device."""
text = self.processor.apply_chat_template( text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
) )
image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True) image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
videos, video_metadata = _unpack_video_inputs(video_inputs)
inputs = self.processor( videos, video_metadata = None, None
if video_inputs:
videos = [v[0] for v in video_inputs]
video_metadata = [v[1] for v in video_inputs]
return self.processor(
text=[text], text=[text],
images=image_inputs, images=image_inputs,
videos=videos, videos=videos,
@@ -522,15 +179,33 @@ class Qwen35VL(BaseVLM):
padding=True, padding=True,
return_tensors="pt", return_tensors="pt",
).to(self.device) ).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
response = self.processor.batch_decode( def _decode(self, inputs, generated_ids) -> list[str]:
return self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)], [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True, skip_special_tokens=True,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
)[0].strip() )
def segment_skills(
self,
video_path: Path,
episode_duration: float,
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
prompt = create_skill_segmentation_prompt(
coarse_goal, subtask_labels, duration_seconds=episode_duration
)
messages = self._build_messages(video_path, episode_duration, prompt)
inputs = self._prepare_inputs(messages)
with torch.no_grad():
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
response = self._decode(inputs, generated_ids)[0].strip()
return self._parse_skills_response(response) return self._parse_skills_response(response)
def segment_skills_batch( def segment_skills_batch(
@@ -540,38 +215,25 @@ class Qwen35VL(BaseVLM):
coarse_goal: str | None = None, coarse_goal: str | None = None,
subtask_labels: list[str] | None = None, subtask_labels: list[str] | None = None,
) -> list[list[Skill]]: ) -> list[list[Skill]]:
"""Segment multiple videos into skills using Qwen3.5-VL in a batch.""" all_texts = []
all_messages = [] all_video_tuples: list[tuple] = []
for video_path, duration in zip(video_paths, episode_durations, strict=True): for video_path, duration in zip(video_paths, episode_durations, strict=True):
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration) prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}" messages = self._build_messages(video_path, duration, prompt)
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
{
"role": "user",
"content": [
{"type": "video", "video": str(video_path), "fps": 1.0},
{
"type": "text",
"text": f"Video duration: {duration_str} (exactly {duration:.1f} seconds). Segment into atomic skills. Last skill must end at {duration:.1f}.",
},
],
},
]
all_messages.append(messages)
all_texts = []
all_video_tuples = []
for messages in all_messages:
text = self.processor.apply_chat_template( text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
) )
image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True) _image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
all_texts.append(text) all_texts.append(text)
all_video_tuples.extend(video_inputs or []) all_video_tuples.extend(video_inputs or [])
videos, video_metadata = _unpack_video_inputs(all_video_tuples or None) videos, video_metadata = None, None
if all_video_tuples:
videos = [v[0] for v in all_video_tuples]
video_metadata = [v[1] for v in all_video_tuples]
inputs = self.processor( inputs = self.processor(
text=all_texts, text=all_texts,
videos=videos, videos=videos,
@@ -584,94 +246,26 @@ class Qwen35VL(BaseVLM):
).to(self.device) ).to(self.device)
with torch.no_grad(): with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7) generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
responses = self.processor.batch_decode( responses = self._decode(inputs, generated_ids)
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
all_skills = [] all_skills = []
for idx, response in enumerate(responses): for idx, response in enumerate(responses):
try: try:
skills = self._parse_skills_response(response.strip()) skills = self._parse_skills_response(response.strip())
if not skills: if not skills:
print(f"Warning: No skills parsed from response for video {idx}") logger.warning(f"No skills parsed for video {idx}")
all_skills.append(skills) all_skills.append(skills)
except Exception as e: except Exception as e:
print(f"Warning: Failed to parse response for video {idx}: {e}") logger.warning(f"Failed to parse response for video {idx}: {e}")
all_skills.append([]) all_skills.append([])
return all_skills return all_skills
def _parse_skills_response(self, response: str) -> list[Skill]:
"""Parse the VLM response into Skill objects."""
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
try:
data = json.loads(response)
skills_data = data.get("skills", data)
if isinstance(skills_data, list):
return [Skill.from_dict(s) for s in skills_data]
except json.JSONDecodeError:
match = re.search(r"\{.*\}", response, re.DOTALL)
if match:
data = json.loads(match.group())
skills_data = data.get("skills", [])
return [Skill.from_dict(s) for s in skills_data]
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
# VLM Registry - Add new VLMs here
VLM_REGISTRY: dict[str, type[BaseVLM]] = {
# Qwen2-VL variants
"Qwen/Qwen2-VL-2B-Instruct": Qwen2VL,
"Qwen/Qwen2-VL-7B-Instruct": Qwen2VL,
"Qwen/Qwen2-VL-72B-Instruct": Qwen2VL,
# Qwen3-VL variants (MoE)
"Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
# Qwen3.5-VL (Qwen3_5ForConditionalGeneration)
"Qwen/Qwen3.5-27B": Qwen35VL,
"Qwen/Qwen3-VL-8B-Instruct": Qwen35VL,
}
def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16) -> BaseVLM: def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16) -> BaseVLM:
""" """Create a VLM instance. Defaults to QwenVL which supports the Qwen3.5 series."""
Factory function to get the appropriate VLM based on model name. return QwenVL(model_name, device, torch_dtype)
Args:
model_name: HuggingFace model identifier
device: Device to load model on
torch_dtype: Data type for model weights
Returns:
Initialized VLM instance
Raises:
ValueError: If model is not in registry
"""
# Check exact match first
if model_name in VLM_REGISTRY:
return VLM_REGISTRY[model_name](model_name, device, torch_dtype)
# Check for partial matches (e.g., "qwen2" in model name)
model_lower = model_name.lower()
if "qwen3.5" in model_lower:
return Qwen35VL(model_name, device, torch_dtype)
if "qwen3" in model_lower:
return Qwen3VL(model_name, device, torch_dtype)
elif "qwen2" in model_lower or "qwen-vl" in model_lower:
return Qwen2VL(model_name, device, torch_dtype)
raise ValueError(
f"Unknown model: {model_name}. "
f"Supported models: {list(VLM_REGISTRY.keys())}. "
"Or implement a new VLM class inheriting from BaseVLM."
)