make it work
+258 -166
@@ -30,7 +30,6 @@ The pipeline:
Supported VLMs (modular design allows easy extension):
- Qwen2-VL (default): "Qwen/Qwen2-VL-7B-Instruct"
- Qwen3-VL: "Qwen/Qwen3-VL-30B-A3B-Instruct"
- SmolVLM: "HuggingFaceTB/SmolVLM-Instruct"

Usage:
```bash
@@ -52,7 +51,7 @@ After running, you can access the skill for any frame via:
dataset = LeRobotDataset(repo_id="your/dataset")
item = dataset[100]
task_idx = item["task_index"]
skill_name = dataset.meta.tasks.iloc[task_idx].name
skill_name = dataset.meta.tasks.iloc[task_idx.item()].name
```
"""
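A small sketch extending the lookup above across a run of frames; the repo id, frame range, and the `LeRobotDataset` import path are placeholders and may differ by lerobot version:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # import path may differ across lerobot versions

dataset = LeRobotDataset(repo_id="your/dataset")  # placeholder repo id
for i in range(100, 110):  # arbitrary frame range
    item = dataset[i]
    task_idx = item["task_index"]
    # Same lookup as the docstring above: task_index -> skill/task name.
    skill_name = dataset.meta.tasks.iloc[task_idx.item()].name
    print(i, skill_name)
```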
@@ -125,7 +124,7 @@ class BaseVLM(ABC):

    To add a new VLM:
    1. Create a subclass of BaseVLM
    2. Implement the `__init__` and `segment_skills` methods
    2. Implement the `__init__`, `segment_skills`, and `segment_skills_batch` methods
    3. Register it in the VLM_REGISTRY dictionary
    """
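The three extension steps in the BaseVLM docstring above can be summarized in a minimal sketch; `MyVLM`, its registry key, and the stub bodies are hypothetical and not part of this commit:

```python
class MyVLM(BaseVLM):
    """Hypothetical backend illustrating the extension steps (step 1: subclass BaseVLM)."""

    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        ...  # load the new model/processor here

    def segment_skills(self, video_path, episode_duration, coarse_goal=None):
        ...  # step 2: return a list[Skill] for one video

    def segment_skills_batch(self, video_paths, episode_durations, coarse_goal=None):
        ...  # step 2: return one list[Skill] per video


# Step 3: register the checkpoint name so get_vlm() can resolve it.
VLM_REGISTRY["my-org/my-vlm-instruct"] = MyVLM
```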
@@ -151,6 +150,23 @@ class BaseVLM(ABC):
        """
        pass

    @abstractmethod
    def segment_skills_batch(
        self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
    ) -> list[list[Skill]]:
        """
        Segment multiple videos into atomic skills in a single batch.

        Args:
            video_paths: List of paths to video files
            episode_durations: List of episode durations in seconds
            coarse_goal: Optional high-level task description

        Returns:
            List of skill lists, one for each video
        """
        pass


def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
    """Create the prompt for skill segmentation."""
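For orientation, a hedged sketch of how the batch interface defined above would be called, assuming `Skill` exposes the `name`/`start`/`end` fields used elsewhere in this file; the clip paths, durations, and goal string are placeholders:

```python
from pathlib import Path

# get_vlm() is the factory defined later in this file.
vlm = get_vlm("Qwen/Qwen2-VL-7B-Instruct", device="cuda")

skills_per_video = vlm.segment_skills_batch(
    video_paths=[Path("episode_000.mp4"), Path("episode_001.mp4")],  # placeholder clips
    episode_durations=[12.5, 9.8],  # seconds, placeholder values
    coarse_goal="Pick up the cube and place it in the bin",  # optional high-level task
)
for video_idx, skills in enumerate(skills_per_video):
    for skill in skills:
        print(video_idx, skill.name, skill.start, skill.end)
```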
@@ -258,6 +274,71 @@ class Qwen2VL(BaseVLM):

        return self._parse_skills_response(response)

    def segment_skills_batch(
        self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
    ) -> list[list[Skill]]:
        """Segment multiple videos into skills using Qwen2-VL in a batch."""
        prompt = create_skill_segmentation_prompt(coarse_goal)

        # Create messages for each video
        all_messages = []
        for video_path, duration in zip(video_paths, episode_durations):
            duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
            messages = [
                {"role": "system", "content": [{"type": "text", "text": prompt}]},
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "video": str(video_path), "fps": 1.0},
                        {
                            "type": "text",
                            "text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
                        },
                    ],
                },
            ]
            all_messages.append(messages)

        # Process all videos in batch
        all_texts = []
        all_image_inputs = []
        all_video_inputs = []

        for messages in all_messages:
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, video_inputs = self.process_vision_info(messages)
            all_texts.append(text)
            all_image_inputs.extend(image_inputs or [])
            all_video_inputs.extend(video_inputs or [])

        inputs = self.processor(
            text=all_texts,
            images=all_image_inputs if all_image_inputs else None,
            videos=all_video_inputs if all_video_inputs else None,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)

        responses = self.processor.batch_decode(
            [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
            skip_special_tokens=True,
        )

        # Parse each response
        all_skills = []
        for response in responses:
            try:
                skills = self._parse_skills_response(response.strip())
                all_skills.append(skills)
            except Exception as e:
                self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
                all_skills.append([])

        return all_skills

    def _parse_skills_response(self, response: str) -> list[Skill]:
        """Parse the VLM response into Skill objects."""
        # Extract JSON from response
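For reference, a minimal sketch of the JSON payload shape that `_parse_skills_response` is written against; the real schema is fixed by `create_skill_segmentation_prompt` (not shown in this hunk), and the field names here are inferred from how `Skill` objects are built elsewhere in the file:

```python
import json

# Illustrative payload only; skill names and timings are made up.
payload = '{"skills": [{"name": "reach for cube", "start": 0.0, "end": 3.5}, {"name": "place cube in bin", "start": 3.5, "end": 8.0}]}'

skills_data = json.loads(payload).get("skills", [])
for s in skills_data:
    print(s["name"], s["start"], s["end"])  # the fields consumed by Skill.from_dict(...)
```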
@@ -349,6 +430,71 @@ class Qwen3VL(BaseVLM):

        return self._parse_skills_response(response)

    def segment_skills_batch(
        self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
    ) -> list[list[Skill]]:
        """Segment multiple videos into skills using Qwen3-VL in a batch."""
        prompt = create_skill_segmentation_prompt(coarse_goal)

        # Create messages for each video
        all_messages = []
        for video_path, duration in zip(video_paths, episode_durations):
            duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
            messages = [
                {"role": "system", "content": [{"type": "text", "text": prompt}]},
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "video": str(video_path), "fps": 1.0},
                        {
                            "type": "text",
                            "text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
                        },
                    ],
                },
            ]
            all_messages.append(messages)

        # Process all videos in batch
        all_texts = []
        all_image_inputs = []
        all_video_inputs = []

        for messages in all_messages:
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, video_inputs = self.process_vision_info(messages)
            all_texts.append(text)
            all_image_inputs.extend(image_inputs or [])
            all_video_inputs.extend(video_inputs or [])

        inputs = self.processor(
            text=all_texts,
            images=all_image_inputs if all_image_inputs else None,
            videos=all_video_inputs if all_video_inputs else None,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)

        responses = self.processor.batch_decode(
            [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
            skip_special_tokens=True,
        )

        # Parse each response
        all_skills = []
        for response in responses:
            try:
                skills = self._parse_skills_response(response.strip())
                all_skills.append(skills)
            except Exception as e:
                self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
                all_skills.append([])

        return all_skills

    def _parse_skills_response(self, response: str) -> list[Skill]:
        """Parse the VLM response into Skill objects."""
        if "```json" in response:
@@ -371,137 +517,6 @@ class Qwen3VL(BaseVLM):
        raise ValueError(f"Could not parse skills from response: {response[:200]}...")


# =============================================================================
# SmolVLM Implementation
# =============================================================================


class SmolVLM(BaseVLM):
    """SmolVLM model for skill segmentation (lighter weight alternative)."""

    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        from transformers import AutoModelForVision2Seq, AutoProcessor

        self.console = Console()
        self.device = device
        self.model_name = model_name

        self.console.print(f"[cyan]Loading SmolVLM model: {model_name}...[/cyan]")

        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            # _attn_implementation="flash_attention_2" if device == "cuda" else "eager",
        ).to(device)

        self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")

    def segment_skills(
        self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
    ) -> list[Skill]:
        """Segment video into skills using SmolVLM with frame sampling."""
        import PIL.Image

        # SmolVLM works with images, so we sample frames from the video
        frames = self._extract_frames(video_path, target_fps=1)

        if not frames:
            raise ValueError(f"Could not extract frames from {video_path}")

        prompt = create_skill_segmentation_prompt(coarse_goal)
        duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"

        # Sample frames (up to 8 frames to avoid context overflow)
        frame_indices = self._select_frame_indices(len(frames), max_frames=8)

        # Convert frames to PIL images
        pil_images = [
            PIL.Image.fromarray(cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB))
            for idx in frame_indices
        ]

        # Create message content with image placeholders
        content = [{"type": "text", "text": prompt}]

        # Add image placeholders (one for each frame)
        for _ in frame_indices:
            content.append({"type": "image"})

        content.append(
            {
                "type": "text",
                "text": f"These are {len(frame_indices)} sampled frames from a {duration_str} video. Segment into atomic skills.",
            }
        )

        messages = [{"role": "user", "content": content}]

        # Apply chat template to get the prompt
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)

        # Process inputs with both text and images
        inputs = self.processor(text=prompt, images=pil_images, return_tensors="pt")
        inputs = inputs.to(self.device)

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)

        response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        return self._parse_skills_response(response, episode_duration)

    def _extract_frames(self, video_path: Path, target_fps: int = 1) -> list:
        """Extract frames from video at target FPS."""
        cap = cv2.VideoCapture(str(video_path))
        frames = []
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        frame_interval = int(fps / target_fps)

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                frames.append(frame)
            frame_count += 1

        cap.release()
        return frames

    def _select_frame_indices(self, total_frames: int, max_frames: int = 8) -> list[int]:
        """Select evenly spaced frame indices."""
        if total_frames <= max_frames:
            return list(range(total_frames))
        step = total_frames / max_frames
        return [int(i * step) for i in range(max_frames)]

    def _parse_skills_response(self, response: str, episode_duration: float) -> list[Skill]:
        """Parse the VLM response into Skill objects."""
        if "```json" in response:
            response = response.split("```json")[1].split("```")[0]
        elif "```" in response:
            response = response.split("```")[1].split("```")[0]

        try:
            data = json.loads(response)
            skills_data = data.get("skills", data)
            breakpoint()
            if isinstance(skills_data, list):
                return [Skill.from_dict(s) for s in skills_data]
        except json.JSONDecodeError:
            match = re.search(r"\{.*\}", response, re.DOTALL)
            if match:
                data = json.loads(match.group())
                skills_data = data.get("skills", [])
                return [Skill.from_dict(s) for s in skills_data]

        # Fallback: create a single skill covering the whole episode
        self.console.print("[yellow]Warning: Could not parse skills, creating single skill[/yellow]")
        return [Skill(name="perform manipulation", start=0.0, end=episode_duration)]


# =============================================================================
# VLM Registry - Add new VLMs here
# =============================================================================
@@ -513,10 +528,6 @@ VLM_REGISTRY: dict[str, type[BaseVLM]] = {
    "Qwen/Qwen2-VL-72B-Instruct": Qwen2VL,
    # Qwen3-VL variants (MoE)
    "Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
    # SmolVLM variants
    "HuggingFaceTB/SmolVLM-Instruct": SmolVLM,
    "HuggingFaceTB/SmolVLM-256M-Instruct": SmolVLM,
    "HuggingFaceTB/SmolVLM-500M-Instruct": SmolVLM,
}
@@ -545,8 +556,6 @@ def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = to
        return Qwen3VL(model_name, device, torch_dtype)
    elif "qwen2" in model_lower or "qwen-vl" in model_lower:
        return Qwen2VL(model_name, device, torch_dtype)
    elif "smolvlm" in model_lower:
        return SmolVLM(model_name, device, torch_dtype)

    raise ValueError(
        f"Unknown model: {model_name}. "
@@ -660,10 +669,12 @@ class SkillAnnotator:
        vlm: BaseVLM,
        video_extractor: VideoExtractor | None = None,
        console: Console | None = None,
        batch_size: int = 8,
    ):
        self.vlm = vlm
        self.console = console or Console()
        self.video_extractor = video_extractor or VideoExtractor(self.console)
        self.batch_size = batch_size

    def annotate_dataset(
        self,
@@ -673,7 +684,7 @@ class SkillAnnotator:
        skip_existing: bool = False,
    ) -> dict[int, EpisodeSkills]:
        """
        Annotate all episodes in a dataset with skill labels.
        Annotate all episodes in a dataset with skill labels using batched processing.

        Args:
            dataset: LeRobot dataset to annotate
@@ -690,32 +701,45 @@ class SkillAnnotator:
        # Get coarse task description if available
        coarse_goal = self._get_coarse_goal(dataset)

        # with Progress(
        #     SpinnerColumn(),
        #     TextColumn("[progress.description]{task.description}"),
        #     console=self.console,
        # ) as progress:
        #     task = progress.add_task(f"Annotating {len(episode_indices)} episodes...", total=len(episode_indices))
        print(f"Annotating {len(episode_indices)} episodes...")
        print(f"Annotating {len(episode_indices)} episodes in batches of {self.batch_size}...")

        for ep_idx in episode_indices:
            # progress.update(task, description=f"Processing episode {ep_idx}...")
            print(f"Processing episode {ep_idx}...")
        # Process episodes in batches
        for batch_start in range(0, len(episode_indices), self.batch_size):
            batch_end = min(batch_start + self.batch_size, len(episode_indices))
            batch_episodes = episode_indices[batch_start:batch_end]

            print(f"Processing batch {batch_start//self.batch_size + 1}/{(len(episode_indices) + self.batch_size - 1)//self.batch_size} (episodes {batch_episodes[0]} to {batch_episodes[-1]})...")

            try:
                skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
                annotations[ep_idx] = EpisodeSkills(
                    episode_index=ep_idx,
                    description=coarse_goal,
                    skills=skills,
                )
                self.console.print(
                    f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
                batch_annotations = self._annotate_episodes_batch(
                    dataset, batch_episodes, video_key, coarse_goal
                )

                for ep_idx, skills in batch_annotations.items():
                    annotations[ep_idx] = EpisodeSkills(
                        episode_index=ep_idx,
                        description=coarse_goal,
                        skills=skills,
                    )
                    self.console.print(
                        f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
                    )
            except Exception as e:
                self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")

                # progress.advance(task)
                self.console.print(f"[red]✗ Batch failed: {e}. Falling back to single-episode processing...[/red]")
                # Fallback: process episodes one by one
                for ep_idx in batch_episodes:
                    try:
                        skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
                        annotations[ep_idx] = EpisodeSkills(
                            episode_index=ep_idx,
                            description=coarse_goal,
                            skills=skills,
                        )
                        self.console.print(
                            f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
                        )
                    except Exception as e:
                        self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")

        return annotations
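The loop above walks episode_indices in fixed-size chunks; a standalone sketch of the same slicing arithmetic, using 20 placeholder episodes and the default batch size of 8:

```python
episode_indices = list(range(20))  # placeholder episode ids
batch_size = 8
n_batches = (len(episode_indices) + batch_size - 1) // batch_size  # ceil division -> 3

for batch_start in range(0, len(episode_indices), batch_size):
    batch_end = min(batch_start + batch_size, len(episode_indices))
    batch_episodes = episode_indices[batch_start:batch_end]
    print(f"batch {batch_start // batch_size + 1}/{n_batches}: episodes {batch_episodes[0]} to {batch_episodes[-1]}")
# batch 1/3: episodes 0 to 7
# batch 2/3: episodes 8 to 15
# batch 3/3: episodes 16 to 19
```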
@@ -730,6 +754,67 @@ class SkillAnnotator:

        return "Perform the demonstrated manipulation task."

    def _annotate_episodes_batch(
        self,
        dataset: LeRobotDataset,
        episode_indices: list[int],
        video_key: str,
        coarse_goal: str,
    ) -> dict[int, list[Skill]]:
        """Annotate multiple episodes with skill labels in a batch."""
        # Extract all videos for this batch
        extracted_paths = []
        durations = []
        valid_episode_indices = []

        for ep_idx in episode_indices:
            try:
                # Get video path and timestamps
                video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)

                if not video_path.exists():
                    self.console.print(f"[yellow]Warning: Video not found for episode {ep_idx}[/yellow]")
                    continue

                # Get episode timestamps from metadata
                ep = dataset.meta.episodes[ep_idx]
                start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
                end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
                duration = end_ts - start_ts

                # Extract episode segment to temporary file
                extracted_path = self.video_extractor.extract_episode_video(
                    video_path, start_ts, end_ts, target_fps=1
                )

                extracted_paths.append(extracted_path)
                durations.append(duration)
                valid_episode_indices.append(ep_idx)

            except Exception as e:
                self.console.print(f"[yellow]Warning: Failed to extract video for episode {ep_idx}: {e}[/yellow]")
                continue

        if not extracted_paths:
            return {}

        try:
            # Run VLM skill segmentation in batch
            all_skills = self.vlm.segment_skills_batch(extracted_paths, durations, coarse_goal)

            # Map results back to episode indices
            results = {}
            for ep_idx, skills in zip(valid_episode_indices, all_skills):
                results[ep_idx] = skills

            return results

        finally:
            # Clean up all temporary files
            for path in extracted_paths:
                if path.exists():
                    path.unlink()

    def _annotate_episode(
        self,
        dataset: LeRobotDataset,
@@ -1007,15 +1092,15 @@ def load_skill_annotations(dataset_root: Path) -> dict | None:
def main():
    """Main entry point for the skill annotation script."""
    parser = argparse.ArgumentParser(
        description="Automatic skill annotation for LeRobot datasets using VLMs",
        description="Automatic skill annotation for LeRobot datasets using VLMs (with batched processing)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent("""\
            Examples:
              # Annotate a HuggingFace Hub dataset
              python annotate.py --repo-id user/dataset --video-key observation.images.base

              # Annotate a local dataset
              python annotate.py --data-dir /path/to/dataset --video-key observation.images.base
              # Annotate a local dataset with custom batch size
              python annotate.py --data-dir /path/to/dataset --video-key observation.images.base --batch-size 16

              # Use a specific model
              python annotate.py --repo-id user/dataset --video-key observation.images.base \\
@@ -1059,6 +1144,12 @@ def main():
        choices=["bfloat16", "float16", "float32"],
        help="Model dtype (default: bfloat16)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Number of episodes to process in each batch (default: 8)",
    )

    # Episode selection
    parser.add_argument(
@@ -1116,7 +1207,8 @@ def main():
    vlm = get_vlm(args.model, args.device, torch_dtype)

    # Create annotator and run annotation
    annotator = SkillAnnotator(vlm=vlm, console=console)
    annotator = SkillAnnotator(vlm=vlm, console=console, batch_size=args.batch_size)
    console.print(f"[cyan]Processing with batch size: {args.batch_size}[/cyan]")
    annotations = annotator.annotate_dataset(
        dataset=dataset,
        video_key=args.video_key,
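Putting the new pieces together, a hedged end-to-end sketch of the batched path added by this commit, driven from Python rather than the CLI; the repo id, video key, and device are placeholders, and the names are assumed to be importable from this annotation module:

```python
import torch

dataset = LeRobotDataset(repo_id="user/dataset")  # placeholder dataset
vlm = get_vlm("Qwen/Qwen2-VL-7B-Instruct", device="cuda", torch_dtype=torch.bfloat16)

# batch_size is the new knob introduced by this commit (default 8).
annotator = SkillAnnotator(vlm=vlm, batch_size=8)
annotations = annotator.annotate_dataset(
    dataset=dataset,
    video_key="observation.images.base",  # same key as the CLI examples above
)
for ep_idx, ep_skills in annotations.items():
    print(ep_idx, [s.name for s in ep_skills.skills])
```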