mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
refactor: consolidate VLM classes into single QwenVL implementation
Remove Qwen2VL, Qwen3VL, and Qwen35VL in favor of a single QwenVL class that uses AutoModelForImageTextToText and works with the whole Qwen VL family. Move the shared _parse_skills_response into BaseVLM and extract the _build_messages, _prepare_inputs, and _decode helpers to reduce duplication.
Made-with: Cursor
This commit is contained in:
@@ -12,9 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# VLM Interface (Abstract Base Class for Modularity)
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
@@ -27,75 +26,9 @@ from lerobot.utils.constants import (
|
||||
format_subtask_labels_section,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseVLM(ABC):
    """Interface that every Vision-Language Model backend must implement.

    Adding a backend takes three steps:
        1. Derive a class from BaseVLM.
        2. Provide concrete `__init__`, `segment_skills`, and
           `segment_skills_batch` implementations.
        3. Add the class to the VLM_REGISTRY dictionary.
    """

    @abstractmethod
    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        """Set up the backend.

        Args:
            model_name: Identifier of the model to load.
            device: Device the model should run on.
            torch_dtype: Dtype to use for the model.
        """

    @abstractmethod
    def segment_skills(
        self,
        video_path: Path,
        episode_duration: float,
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[Skill]:
        """Split one video into atomic skills.

        Args:
            video_path: Location of the video file.
            episode_duration: Length of the episode, in seconds.
            coarse_goal: Optional high-level description of the task.
            subtask_labels: Optional closed vocabulary; when given, the model
                must pick its labels from this list only.

        Returns:
            The atomic manipulation skills found in the video, as Skill objects.
        """

    @abstractmethod
    def segment_skills_batch(
        self,
        video_paths: list[Path],
        episode_durations: list[float],
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[list[Skill]]:
        """Split several videos into atomic skills in one batch.

        Args:
            video_paths: Locations of the video files.
            episode_durations: Episode lengths in seconds, one per video.
            coarse_goal: Optional high-level description of the task.
            subtask_labels: Optional closed vocabulary of allowed labels
                (same meaning as in `segment_skills`).

        Returns:
            One list of Skill objects per input video.
        """
|
||||
|
||||
|
||||
def _unpack_video_inputs(
|
||||
video_inputs: list | None,
|
||||
) -> tuple[list | None, list[dict] | None]:
|
||||
"""Unpack (tensor, metadata) tuples returned by process_vision_info with return_video_metadata=True."""
|
||||
if not video_inputs:
|
||||
return None, None
|
||||
videos = [v[0] for v in video_inputs]
|
||||
metadata = [v[1] for v in video_inputs]
|
||||
return videos, metadata
|
||||
DEFAULT_MODEL = "Qwen/Qwen3.5-27B"
|
||||
|
||||
|
||||
def create_skill_segmentation_prompt(
|
||||
@@ -103,9 +36,7 @@ def create_skill_segmentation_prompt(
|
||||
subtask_labels: list[str] | None = None,
|
||||
duration_seconds: float | None = None,
|
||||
) -> str:
|
||||
"""Create the prompt for skill segmentation using the template from constants.
|
||||
duration_seconds is required. When subtask_labels is provided, uses closed-vocabulary section.
|
||||
"""
|
||||
"""Create the prompt for skill segmentation using the template from constants."""
|
||||
if duration_seconds is None:
|
||||
raise ValueError("duration_seconds is required for skill segmentation prompt")
|
||||
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
|
||||
@@ -119,29 +50,21 @@ def create_skill_segmentation_prompt(
|
||||
)
|
||||
|
||||
|
||||
# Qwen2-VL Implementation
|
||||
class BaseVLM(ABC):
|
||||
"""
|
||||
Abstract base class for Vision-Language Models used in skill segmentation.
|
||||
|
||||
To add a new VLM family:
|
||||
1. Subclass BaseVLM
|
||||
2. Implement __init__, segment_skills, and segment_skills_batch
|
||||
3. Register it in get_vlm()
|
||||
"""
|
||||
|
||||
class Qwen2VL(BaseVLM):
|
||||
"""Qwen2-VL model for skill segmentation."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
self.device = device
|
||||
self.model_name = model_name
|
||||
self.process_vision_info = process_vision_info
|
||||
|
||||
print(f"Loading Qwen2-VL model: {model_name}...")
|
||||
|
||||
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
print(f" Model loaded successfully on {device}")
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def segment_skills(
|
||||
self,
|
||||
video_path: Path,
|
||||
@@ -149,48 +72,10 @@ class Qwen2VL(BaseVLM):
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[Skill]:
|
||||
"""Segment video into skills using Qwen2-VL."""
|
||||
prompt = create_skill_segmentation_prompt(
|
||||
coarse_goal, subtask_labels, duration_seconds=episode_duration
|
||||
)
|
||||
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": str(video_path), "fps": 1.0},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Video duration: {duration_str} (exactly {episode_duration:.1f} seconds). Segment into atomic skills. Last skill must end at {episode_duration:.1f}.",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(
|
||||
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||
)
|
||||
|
||||
response = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
)[0].strip()
|
||||
|
||||
return self._parse_skills_response(response)
|
||||
"""Segment a single video into atomic skills."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def segment_skills_batch(
|
||||
self,
|
||||
video_paths: list[Path],
|
||||
@@ -198,69 +83,11 @@ class Qwen2VL(BaseVLM):
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[list[Skill]]:
|
||||
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
|
||||
all_messages = []
|
||||
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": str(video_path), "fps": 1.0},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Video duration: {duration_str} (exactly {duration:.1f} seconds). Segment into atomic skills. Last skill must end at {duration:.1f}.",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
all_messages.append(messages)
|
||||
|
||||
all_texts = []
|
||||
all_video_inputs = []
|
||||
|
||||
for messages in all_messages:
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
all_texts.append(text)
|
||||
all_video_inputs.extend(video_inputs or [])
|
||||
|
||||
inputs = self.processor(
|
||||
text=all_texts,
|
||||
videos=all_video_inputs or None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(
|
||||
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||
)
|
||||
|
||||
responses = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
# Parse each response
|
||||
all_skills = []
|
||||
for idx, response in enumerate(responses):
|
||||
try:
|
||||
skills = self._parse_skills_response(response.strip())
|
||||
if not skills:
|
||||
print(f"Warning: No skills parsed from response for video {idx}")
|
||||
all_skills.append(skills)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse response for video {idx}: {e}")
|
||||
all_skills.append([])
|
||||
|
||||
return all_skills
|
||||
"""Segment multiple videos into atomic skills in a single batch."""
|
||||
pass
|
||||
|
||||
def _parse_skills_response(self, response: str) -> list[Skill]:
|
||||
"""Parse the VLM response into Skill objects."""
|
||||
# Extract JSON from response
|
||||
"""Parse JSON skill list from VLM response text."""
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
@@ -272,7 +99,6 @@ class Qwen2VL(BaseVLM):
|
||||
if isinstance(skills_data, list):
|
||||
return [Skill.from_dict(s) for s in skills_data]
|
||||
except json.JSONDecodeError:
|
||||
# Try to find JSON object in response
|
||||
match = re.search(r"\{.*\}", response, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
@@ -280,219 +106,40 @@ class Qwen2VL(BaseVLM):
|
||||
skills_data = data.get("skills", [])
|
||||
return [Skill.from_dict(s) for s in skills_data]
|
||||
except json.JSONDecodeError as e:
|
||||
excerpt = response[:200]
|
||||
raise ValueError(
|
||||
f"Could not parse JSON from VLM response (fallback failed): {excerpt}..."
|
||||
) from e
|
||||
raise ValueError(f"Could not parse JSON from VLM response: {response[:200]}...") from e
|
||||
|
||||
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
|
||||
|
||||
|
||||
# Qwen3-VL Implementation (MoE variant)
|
||||
class QwenVL(BaseVLM):
|
||||
"""Qwen VL model for skill segmentation (default: Qwen3.5 series).
|
||||
|
||||
|
||||
class Qwen3VL(BaseVLM):
|
||||
"""Qwen3-VL MoE model for skill segmentation."""
|
||||
Uses qwen-vl-utils for video processing and the HuggingFace transformers
|
||||
Qwen3VLProcessor pipeline. Requires transformers >= 5.4.0 for correct
|
||||
video position embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor
|
||||
|
||||
self.device = device
|
||||
self.model_name = model_name
|
||||
self.process_vision_info = process_vision_info
|
||||
|
||||
print(f"Loading Qwen3-VL model: {model_name}...")
|
||||
logger.info(f"Loading model: {model_name}...")
|
||||
|
||||
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
print(f" Model loaded successfully on {device}")
|
||||
|
||||
def segment_skills(
    self,
    video_path: Path,
    episode_duration: float,
    coarse_goal: str | None = None,
    subtask_labels: list[str] | None = None,
) -> list[Skill]:
    """Run the model on one video and parse its reply into skills.

    Args:
        video_path: Location of the video file.
        episode_duration: Length of the episode, in seconds.
        coarse_goal: Optional high-level task description, forwarded to the
            prompt builder.
        subtask_labels: Optional label list, forwarded to the prompt builder.

    Returns:
        Whatever `_parse_skills_response` extracts from the generated text.
    """
    system_prompt = create_skill_segmentation_prompt(
        coarse_goal, subtask_labels, duration_seconds=episode_duration
    )

    # Human-readable MM:SS form of the duration for the user turn.
    minutes = int(episode_duration // 60)
    seconds = int(episode_duration % 60)
    duration_str = f"{minutes:02d}:{seconds:02d}"
    user_text = f"Video duration: {duration_str} (exactly {episode_duration:.1f} seconds). Segment into atomic skills. Last skill must end at {episode_duration:.1f}."

    video_entry = {"type": "video", "video": str(video_path), "fps": 1.0}
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": [video_entry, {"type": "text", "text": user_text}]},
    ]

    chat_text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
    videos, video_metadata = _unpack_video_inputs(video_inputs)

    # NOTE(review): do_sample_frames=False presumably tells the processor the
    # frames were already sampled upstream - confirm against the processor docs.
    video_options = {"video_metadata": video_metadata, "do_sample_frames": False}
    inputs = self.processor(
        text=[chat_text],
        images=image_inputs,
        videos=videos,
        videos_kwargs=video_options,
        padding=True,
        return_tensors="pt",
    ).to(self.device)

    with torch.no_grad():
        generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)

    # Drop the prompt tokens so only the newly generated part is decoded.
    completions = []
    for prompt_ids, output_ids in zip(inputs.input_ids, generated_ids, strict=True):
        completions.append(output_ids[len(prompt_ids) :])
    decoded = self.processor.batch_decode(completions, skip_special_tokens=True)

    return self._parse_skills_response(decoded[0].strip())
|
||||
|
||||
def segment_skills_batch(
    self,
    video_paths: list[Path],
    episode_durations: list[float],
    coarse_goal: str | None = None,
    subtask_labels: list[str] | None = None,
) -> list[list[Skill]]:
    """Run the model on several videos at once and parse every reply.

    Args:
        video_paths: Locations of the video files.
        episode_durations: Episode lengths in seconds; must have the same
            length as `video_paths` (the pairing uses strict zip).
        coarse_goal: Optional high-level task description, forwarded to the
            prompt builder.
        subtask_labels: Optional label list, forwarded to the prompt builder.

    Returns:
        One skill list per decoded response. A response that cannot be parsed
        contributes an empty list and a printed warning instead of raising.
    """
    # Pass 1: build one conversation per video.
    conversations = []
    for video_path, duration in zip(video_paths, episode_durations, strict=True):
        system_prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
        minutes = int(duration // 60)
        seconds = int(duration % 60)
        duration_str = f"{minutes:02d}:{seconds:02d}"
        user_text = f"Video duration: {duration_str} (exactly {duration:.1f} seconds). Segment into atomic skills. Last skill must end at {duration:.1f}."
        video_entry = {"type": "video", "video": str(video_path), "fps": 1.0}
        conversations.append(
            [
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
                {"role": "user", "content": [video_entry, {"type": "text", "text": user_text}]},
            ]
        )

    # Pass 2: render each conversation to text and collect its video inputs.
    all_texts = []
    all_video_tuples = []
    for conversation in conversations:
        rendered = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        _, video_inputs = self.process_vision_info(conversation, return_video_metadata=True)
        all_texts.append(rendered)
        if video_inputs:
            all_video_tuples.extend(video_inputs)

    videos, video_metadata = _unpack_video_inputs(all_video_tuples or None)
    video_options = {"video_metadata": video_metadata, "do_sample_frames": False}
    inputs = self.processor(
        text=all_texts,
        videos=videos,
        videos_kwargs=video_options,
        padding=True,
        return_tensors="pt",
    ).to(self.device)

    with torch.no_grad():
        generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)

    # Drop the prompt tokens so only the newly generated part is decoded.
    completions = []
    for prompt_ids, output_ids in zip(inputs.input_ids, generated_ids, strict=True):
        completions.append(output_ids[len(prompt_ids) :])
    responses = self.processor.batch_decode(completions, skip_special_tokens=True)

    # Parse each reply; failures are reported and replaced by an empty list.
    all_skills = []
    for idx, response in enumerate(responses):
        try:
            skills = self._parse_skills_response(response.strip())
            if not skills:
                print(f"Warning: No skills parsed from response for video {idx}")
            all_skills.append(skills)
        except Exception as e:
            print(f"Warning: Failed to parse response for video {idx}: {e}")
            all_skills.append([])

    return all_skills
|
||||
|
||||
def _parse_skills_response(self, response: str) -> list[Skill]:
|
||||
"""Parse the VLM response into Skill objects."""
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0]
|
||||
|
||||
try:
|
||||
data = json.loads(response)
|
||||
skills_data = data.get("skills", data)
|
||||
if isinstance(skills_data, list):
|
||||
return [Skill.from_dict(s) for s in skills_data]
|
||||
except json.JSONDecodeError:
|
||||
match = re.search(r"\{.*\}", response, re.DOTALL)
|
||||
if match:
|
||||
data = json.loads(match.group())
|
||||
skills_data = data.get("skills", [])
|
||||
return [Skill.from_dict(s) for s in skills_data]
|
||||
|
||||
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
|
||||
|
||||
|
||||
# Qwen3.5-VL Implementation (Qwen3_5ForConditionalGeneration)
|
||||
|
||||
|
||||
class Qwen35VL(BaseVLM):
|
||||
"""Qwen3.5-VL model for skill segmentation (Qwen3_5ForConditionalGeneration)."""
|
||||
|
||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration
|
||||
|
||||
self.device = device
|
||||
self.model_name = model_name
|
||||
self.process_vision_info = process_vision_info
|
||||
|
||||
print(f"Loading Qwen3.5-VL model: {model_name}...")
|
||||
|
||||
self.model = Qwen3_5ForConditionalGeneration.from_pretrained(
|
||||
self.model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
self.processor.tokenizer.padding_side = "left"
|
||||
print(f" Model loaded successfully on {device}")
|
||||
|
||||
def segment_skills(
|
||||
self,
|
||||
video_path: Path,
|
||||
episode_duration: float,
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[Skill]:
|
||||
"""Segment video into skills using Qwen3.5-VL."""
|
||||
prompt = create_skill_segmentation_prompt(
|
||||
coarse_goal, subtask_labels, duration_seconds=episode_duration
|
||||
)
|
||||
logger.info(f"Model loaded on {device}")
|
||||
|
||||
def _build_messages(self, video_path: Path, episode_duration: float, prompt: str) -> list[dict]:
|
||||
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
||||
messages = [
|
||||
return [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -500,18 +147,28 @@ class Qwen35VL(BaseVLM):
|
||||
{"type": "video", "video": str(video_path), "fps": 1.0},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Video duration: {duration_str} (exactly {episode_duration:.1f} seconds). Segment into atomic skills. Last skill must end at {episode_duration:.1f}.",
|
||||
"text": (
|
||||
f"Video duration: {duration_str} (exactly {episode_duration:.1f} seconds). "
|
||||
f"Segment into atomic skills. Last skill must end at {episode_duration:.1f}."
|
||||
),
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def _prepare_inputs(self, messages: list[dict]) -> dict:
|
||||
"""Tokenize a single message and return processor inputs on device."""
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
||||
)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
|
||||
videos, video_metadata = _unpack_video_inputs(video_inputs)
|
||||
inputs = self.processor(
|
||||
|
||||
videos, video_metadata = None, None
|
||||
if video_inputs:
|
||||
videos = [v[0] for v in video_inputs]
|
||||
video_metadata = [v[1] for v in video_inputs]
|
||||
|
||||
return self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=videos,
|
||||
@@ -522,15 +179,33 @@ class Qwen35VL(BaseVLM):
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
||||
|
||||
response = self.processor.batch_decode(
|
||||
def _decode(self, inputs, generated_ids) -> list[str]:
|
||||
return self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)[0].strip()
|
||||
)
|
||||
|
||||
def segment_skills(
|
||||
self,
|
||||
video_path: Path,
|
||||
episode_duration: float,
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[Skill]:
|
||||
prompt = create_skill_segmentation_prompt(
|
||||
coarse_goal, subtask_labels, duration_seconds=episode_duration
|
||||
)
|
||||
messages = self._build_messages(video_path, episode_duration, prompt)
|
||||
inputs = self._prepare_inputs(messages)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(
|
||||
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||
)
|
||||
|
||||
response = self._decode(inputs, generated_ids)[0].strip()
|
||||
return self._parse_skills_response(response)
|
||||
|
||||
def segment_skills_batch(
|
||||
@@ -540,38 +215,25 @@ class Qwen35VL(BaseVLM):
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[list[Skill]]:
|
||||
"""Segment multiple videos into skills using Qwen3.5-VL in a batch."""
|
||||
all_messages = []
|
||||
all_texts = []
|
||||
all_video_tuples: list[tuple] = []
|
||||
|
||||
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": str(video_path), "fps": 1.0},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Video duration: {duration_str} (exactly {duration:.1f} seconds). Segment into atomic skills. Last skill must end at {duration:.1f}.",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
all_messages.append(messages)
|
||||
messages = self._build_messages(video_path, duration, prompt)
|
||||
|
||||
all_texts = []
|
||||
all_video_tuples = []
|
||||
|
||||
for messages in all_messages:
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
||||
)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
|
||||
_image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
|
||||
all_texts.append(text)
|
||||
all_video_tuples.extend(video_inputs or [])
|
||||
|
||||
videos, video_metadata = _unpack_video_inputs(all_video_tuples or None)
|
||||
videos, video_metadata = None, None
|
||||
if all_video_tuples:
|
||||
videos = [v[0] for v in all_video_tuples]
|
||||
video_metadata = [v[1] for v in all_video_tuples]
|
||||
|
||||
inputs = self.processor(
|
||||
text=all_texts,
|
||||
videos=videos,
|
||||
@@ -584,94 +246,26 @@ class Qwen35VL(BaseVLM):
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
||||
generated_ids = self.model.generate(
|
||||
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||
)
|
||||
|
||||
responses = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
responses = self._decode(inputs, generated_ids)
|
||||
|
||||
all_skills = []
|
||||
for idx, response in enumerate(responses):
|
||||
try:
|
||||
skills = self._parse_skills_response(response.strip())
|
||||
if not skills:
|
||||
print(f"Warning: No skills parsed from response for video {idx}")
|
||||
logger.warning(f"No skills parsed for video {idx}")
|
||||
all_skills.append(skills)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse response for video {idx}: {e}")
|
||||
logger.warning(f"Failed to parse response for video {idx}: {e}")
|
||||
all_skills.append([])
|
||||
|
||||
return all_skills
|
||||
|
||||
def _parse_skills_response(self, response: str) -> list[Skill]:
|
||||
"""Parse the VLM response into Skill objects."""
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0]
|
||||
|
||||
try:
|
||||
data = json.loads(response)
|
||||
skills_data = data.get("skills", data)
|
||||
if isinstance(skills_data, list):
|
||||
return [Skill.from_dict(s) for s in skills_data]
|
||||
except json.JSONDecodeError:
|
||||
match = re.search(r"\{.*\}", response, re.DOTALL)
|
||||
if match:
|
||||
data = json.loads(match.group())
|
||||
skills_data = data.get("skills", [])
|
||||
return [Skill.from_dict(s) for s in skills_data]
|
||||
|
||||
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
|
||||
|
||||
|
||||
# VLM Registry - Add new VLMs here
|
||||
|
||||
# Exact model identifiers mapped to the wrapper class that loads them.
# Entries are grouped by wrapper class; insertion order is Qwen2VL, Qwen3VL,
# then Qwen35VL.
VLM_REGISTRY: dict[str, type[BaseVLM]] = {
    # Checkpoints loaded through Qwen2VL
    **dict.fromkeys(
        (
            "Qwen/Qwen2-VL-2B-Instruct",
            "Qwen/Qwen2-VL-7B-Instruct",
            "Qwen/Qwen2-VL-72B-Instruct",
        ),
        Qwen2VL,
    ),
    # Checkpoints loaded through Qwen3VL (MoE)
    **dict.fromkeys(("Qwen/Qwen3-VL-30B-A3B-Instruct",), Qwen3VL),
    # Checkpoints loaded through Qwen35VL
    # NOTE(review): the Qwen3-VL-8B checkpoint is routed to Qwen35VL rather
    # than Qwen3VL - presumably because it is not a MoE model; confirm.
    **dict.fromkeys(
        (
            "Qwen/Qwen3.5-27B",
            "Qwen/Qwen3-VL-8B-Instruct",
        ),
        Qwen35VL,
    ),
}
|
||||
|
||||
|
||||
def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16) -> BaseVLM:
|
||||
"""
|
||||
Factory function to get the appropriate VLM based on model name.
|
||||
|
||||
Args:
|
||||
model_name: HuggingFace model identifier
|
||||
device: Device to load model on
|
||||
torch_dtype: Data type for model weights
|
||||
|
||||
Returns:
|
||||
Initialized VLM instance
|
||||
|
||||
Raises:
|
||||
ValueError: If model is not in registry
|
||||
"""
|
||||
# Check exact match first
|
||||
if model_name in VLM_REGISTRY:
|
||||
return VLM_REGISTRY[model_name](model_name, device, torch_dtype)
|
||||
|
||||
# Check for partial matches (e.g., "qwen2" in model name)
|
||||
model_lower = model_name.lower()
|
||||
if "qwen3.5" in model_lower:
|
||||
return Qwen35VL(model_name, device, torch_dtype)
|
||||
if "qwen3" in model_lower:
|
||||
return Qwen3VL(model_name, device, torch_dtype)
|
||||
elif "qwen2" in model_lower or "qwen-vl" in model_lower:
|
||||
return Qwen2VL(model_name, device, torch_dtype)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown model: {model_name}. "
|
||||
f"Supported models: {list(VLM_REGISTRY.keys())}. "
|
||||
"Or implement a new VLM class inheriting from BaseVLM."
|
||||
)
|
||||
"""Create a VLM instance. Defaults to QwenVL which supports the Qwen3.5 series."""
|
||||
return QwenVL(model_name, device, torch_dtype)
|
||||
|
||||
Reference in New Issue
Block a user