make it work

This commit is contained in:
Jade Choghari
2025-12-08 14:19:15 +00:00
parent 3568df8a35
commit 9091b68d86
+258 -166
View File
@@ -30,7 +30,6 @@ The pipeline:
Supported VLMs (modular design allows easy extension):
- Qwen2-VL (default): "Qwen/Qwen2-VL-7B-Instruct"
- Qwen3-VL: "Qwen/Qwen3-VL-30B-A3B-Instruct"
- SmolVLM: "HuggingFaceTB/SmolVLM-Instruct"
Usage:
```bash
@@ -52,7 +51,7 @@ After running, you can access the skill for any frame via:
dataset = LeRobotDataset(repo_id="your/dataset")
item = dataset[100]
task_idx = item["task_index"]
skill_name = dataset.meta.tasks.iloc[task_idx].name
skill_name = dataset.meta.tasks.iloc[task_idx.item()].name
```
"""
@@ -125,7 +124,7 @@ class BaseVLM(ABC):
To add a new VLM:
1. Create a subclass of BaseVLM
2. Implement the `__init__` and `segment_skills` methods
2. Implement the `__init__`, `segment_skills`, and `segment_skills_batch` methods
3. Register it in the VLM_REGISTRY dictionary
"""
@@ -151,6 +150,23 @@ class BaseVLM(ABC):
"""
pass
@abstractmethod
def segment_skills_batch(
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
) -> list[list[Skill]]:
"""
Segment multiple videos into atomic skills in a single batch.
Args:
video_paths: List of paths to video files
episode_durations: List of episode durations in seconds
coarse_goal: Optional high-level task description
Returns:
List of skill lists, one for each video
"""
pass
def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
"""Create the prompt for skill segmentation."""
@@ -258,6 +274,71 @@ class Qwen2VL(BaseVLM):
return self._parse_skills_response(response)
def segment_skills_batch(
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
) -> list[list[Skill]]:
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
prompt = create_skill_segmentation_prompt(coarse_goal)
# Create messages for each video
all_messages = []
for video_path, duration in zip(video_paths, episode_durations):
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
{
"role": "user",
"content": [
{"type": "video", "video": str(video_path), "fps": 1.0},
{
"type": "text",
"text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
},
],
},
]
all_messages.append(messages)
# Process all videos in batch
all_texts = []
all_image_inputs = []
all_video_inputs = []
for messages in all_messages:
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = self.process_vision_info(messages)
all_texts.append(text)
all_image_inputs.extend(image_inputs or [])
all_video_inputs.extend(video_inputs or [])
inputs = self.processor(
text=all_texts,
images=all_image_inputs if all_image_inputs else None,
videos=all_video_inputs if all_video_inputs else None,
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
responses = self.processor.batch_decode(
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
skip_special_tokens=True,
)
# Parse each response
all_skills = []
for response in responses:
try:
skills = self._parse_skills_response(response.strip())
all_skills.append(skills)
except Exception as e:
self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
all_skills.append([])
return all_skills
def _parse_skills_response(self, response: str) -> list[Skill]:
"""Parse the VLM response into Skill objects."""
# Extract JSON from response
@@ -349,6 +430,71 @@ class Qwen3VL(BaseVLM):
return self._parse_skills_response(response)
def segment_skills_batch(
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
) -> list[list[Skill]]:
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
prompt = create_skill_segmentation_prompt(coarse_goal)
# Create messages for each video
all_messages = []
for video_path, duration in zip(video_paths, episode_durations):
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
{
"role": "user",
"content": [
{"type": "video", "video": str(video_path), "fps": 1.0},
{
"type": "text",
"text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
},
],
},
]
all_messages.append(messages)
# Process all videos in batch
all_texts = []
all_image_inputs = []
all_video_inputs = []
for messages in all_messages:
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = self.process_vision_info(messages)
all_texts.append(text)
all_image_inputs.extend(image_inputs or [])
all_video_inputs.extend(video_inputs or [])
inputs = self.processor(
text=all_texts,
images=all_image_inputs if all_image_inputs else None,
videos=all_video_inputs if all_video_inputs else None,
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
responses = self.processor.batch_decode(
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
skip_special_tokens=True,
)
# Parse each response
all_skills = []
for response in responses:
try:
skills = self._parse_skills_response(response.strip())
all_skills.append(skills)
except Exception as e:
self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
all_skills.append([])
return all_skills
def _parse_skills_response(self, response: str) -> list[Skill]:
"""Parse the VLM response into Skill objects."""
if "```json" in response:
@@ -371,137 +517,6 @@ class Qwen3VL(BaseVLM):
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
# =============================================================================
# SmolVLM Implementation
# =============================================================================
class SmolVLM(BaseVLM):
"""SmolVLM model for skill segmentation (lighter weight alternative)."""
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
from transformers import AutoModelForVision2Seq, AutoProcessor
self.console = Console()
self.device = device
self.model_name = model_name
self.console.print(f"[cyan]Loading SmolVLM model: {model_name}...[/cyan]")
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
self.model = AutoModelForVision2Seq.from_pretrained(
model_name,
torch_dtype=torch_dtype,
# _attn_implementation="flash_attention_2" if device == "cuda" else "eager",
).to(device)
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
def segment_skills(
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
) -> list[Skill]:
"""Segment video into skills using SmolVLM with frame sampling."""
import PIL.Image
# SmolVLM works with images, so we sample frames from the video
frames = self._extract_frames(video_path, target_fps=1)
if not frames:
raise ValueError(f"Could not extract frames from {video_path}")
prompt = create_skill_segmentation_prompt(coarse_goal)
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
# Sample frames (up to 8 frames to avoid context overflow)
frame_indices = self._select_frame_indices(len(frames), max_frames=8)
# Convert frames to PIL images
pil_images = [
PIL.Image.fromarray(cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB))
for idx in frame_indices
]
# Create message content with image placeholders
content = [{"type": "text", "text": prompt}]
# Add image placeholders (one for each frame)
for _ in frame_indices:
content.append({"type": "image"})
content.append(
{
"type": "text",
"text": f"These are {len(frame_indices)} sampled frames from a {duration_str} video. Segment into atomic skills.",
}
)
messages = [{"role": "user", "content": content}]
# Apply chat template to get the prompt
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
# Process inputs with both text and images
inputs = self.processor(text=prompt, images=pil_images, return_tensors="pt")
inputs = inputs.to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
return self._parse_skills_response(response, episode_duration)
def _extract_frames(self, video_path: Path, target_fps: int = 1) -> list:
"""Extract frames from video at target FPS."""
cap = cv2.VideoCapture(str(video_path))
frames = []
fps = cap.get(cv2.CAP_PROP_FPS) or 30
frame_interval = int(fps / target_fps)
frame_count = 0
while True:
ret, frame = cap.read()
if not ret:
break
if frame_count % frame_interval == 0:
frames.append(frame)
frame_count += 1
cap.release()
return frames
def _select_frame_indices(self, total_frames: int, max_frames: int = 8) -> list[int]:
"""Select evenly spaced frame indices."""
if total_frames <= max_frames:
return list(range(total_frames))
step = total_frames / max_frames
return [int(i * step) for i in range(max_frames)]
def _parse_skills_response(self, response: str, episode_duration: float) -> list[Skill]:
"""Parse the VLM response into Skill objects."""
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
try:
data = json.loads(response)
skills_data = data.get("skills", data)
breakpoint()
if isinstance(skills_data, list):
return [Skill.from_dict(s) for s in skills_data]
except json.JSONDecodeError:
match = re.search(r"\{.*\}", response, re.DOTALL)
if match:
data = json.loads(match.group())
skills_data = data.get("skills", [])
return [Skill.from_dict(s) for s in skills_data]
# Fallback: create a single skill covering the whole episode
self.console.print("[yellow]Warning: Could not parse skills, creating single skill[/yellow]")
return [Skill(name="perform manipulation", start=0.0, end=episode_duration)]
# =============================================================================
# VLM Registry - Add new VLMs here
# =============================================================================
@@ -513,10 +528,6 @@ VLM_REGISTRY: dict[str, type[BaseVLM]] = {
"Qwen/Qwen2-VL-72B-Instruct": Qwen2VL,
# Qwen3-VL variants (MoE)
"Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
# SmolVLM variants
"HuggingFaceTB/SmolVLM-Instruct": SmolVLM,
"HuggingFaceTB/SmolVLM-256M-Instruct": SmolVLM,
"HuggingFaceTB/SmolVLM-500M-Instruct": SmolVLM,
}
@@ -545,8 +556,6 @@ def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = to
return Qwen3VL(model_name, device, torch_dtype)
elif "qwen2" in model_lower or "qwen-vl" in model_lower:
return Qwen2VL(model_name, device, torch_dtype)
elif "smolvlm" in model_lower:
return SmolVLM(model_name, device, torch_dtype)
raise ValueError(
f"Unknown model: {model_name}. "
@@ -660,10 +669,12 @@ class SkillAnnotator:
vlm: BaseVLM,
video_extractor: VideoExtractor | None = None,
console: Console | None = None,
batch_size: int = 8,
):
self.vlm = vlm
self.console = console or Console()
self.video_extractor = video_extractor or VideoExtractor(self.console)
self.batch_size = batch_size
def annotate_dataset(
self,
@@ -673,7 +684,7 @@ class SkillAnnotator:
skip_existing: bool = False,
) -> dict[int, EpisodeSkills]:
"""
Annotate all episodes in a dataset with skill labels.
Annotate all episodes in a dataset with skill labels using batched processing.
Args:
dataset: LeRobot dataset to annotate
@@ -690,32 +701,45 @@ class SkillAnnotator:
# Get coarse task description if available
coarse_goal = self._get_coarse_goal(dataset)
# with Progress(
# SpinnerColumn(),
# TextColumn("[progress.description]{task.description}"),
# console=self.console,
# ) as progress:
# task = progress.add_task(f"Annotating {len(episode_indices)} episodes...", total=len(episode_indices))
print(f"Annotating {len(episode_indices)} episodes...")
print(f"Annotating {len(episode_indices)} episodes in batches of {self.batch_size}...")
for ep_idx in episode_indices:
# progress.update(task, description=f"Processing episode {ep_idx}...")
print(f"Processing episode {ep_idx}...")
# Process episodes in batches
for batch_start in range(0, len(episode_indices), self.batch_size):
batch_end = min(batch_start + self.batch_size, len(episode_indices))
batch_episodes = episode_indices[batch_start:batch_end]
print(f"Processing batch {batch_start//self.batch_size + 1}/{(len(episode_indices) + self.batch_size - 1)//self.batch_size} (episodes {batch_episodes[0]} to {batch_episodes[-1]})...")
try:
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
description=coarse_goal,
skills=skills,
)
self.console.print(
f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
batch_annotations = self._annotate_episodes_batch(
dataset, batch_episodes, video_key, coarse_goal
)
for ep_idx, skills in batch_annotations.items():
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
description=coarse_goal,
skills=skills,
)
self.console.print(
f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
)
except Exception as e:
self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")
# progress.advance(task)
self.console.print(f"[red]✗ Batch failed: {e}. Falling back to single-episode processing...[/red]")
# Fallback: process episodes one by one
for ep_idx in batch_episodes:
try:
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
description=coarse_goal,
skills=skills,
)
self.console.print(
f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
)
except Exception as e:
self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")
return annotations
@@ -730,6 +754,67 @@ class SkillAnnotator:
return "Perform the demonstrated manipulation task."
def _annotate_episodes_batch(
self,
dataset: LeRobotDataset,
episode_indices: list[int],
video_key: str,
coarse_goal: str,
) -> dict[int, list[Skill]]:
"""Annotate multiple episodes with skill labels in a batch."""
# Extract all videos for this batch
extracted_paths = []
durations = []
valid_episode_indices = []
for ep_idx in episode_indices:
try:
# Get video path and timestamps
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
if not video_path.exists():
self.console.print(f"[yellow]Warning: Video not found for episode {ep_idx}[/yellow]")
continue
# Get episode timestamps from metadata
ep = dataset.meta.episodes[ep_idx]
start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
duration = end_ts - start_ts
# Extract episode segment to temporary file
extracted_path = self.video_extractor.extract_episode_video(
video_path, start_ts, end_ts, target_fps=1
)
extracted_paths.append(extracted_path)
durations.append(duration)
valid_episode_indices.append(ep_idx)
except Exception as e:
self.console.print(f"[yellow]Warning: Failed to extract video for episode {ep_idx}: {e}[/yellow]")
continue
if not extracted_paths:
return {}
try:
# Run VLM skill segmentation in batch
all_skills = self.vlm.segment_skills_batch(extracted_paths, durations, coarse_goal)
# Map results back to episode indices
results = {}
for ep_idx, skills in zip(valid_episode_indices, all_skills):
results[ep_idx] = skills
return results
finally:
# Clean up all temporary files
for path in extracted_paths:
if path.exists():
path.unlink()
def _annotate_episode(
self,
dataset: LeRobotDataset,
@@ -1007,15 +1092,15 @@ def load_skill_annotations(dataset_root: Path) -> dict | None:
def main():
"""Main entry point for the skill annotation script."""
parser = argparse.ArgumentParser(
description="Automatic skill annotation for LeRobot datasets using VLMs",
description="Automatic skill annotation for LeRobot datasets using VLMs (with batched processing)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=textwrap.dedent("""\
Examples:
# Annotate a HuggingFace Hub dataset
python annotate.py --repo-id user/dataset --video-key observation.images.base
# Annotate a local dataset
python annotate.py --data-dir /path/to/dataset --video-key observation.images.base
# Annotate a local dataset with custom batch size
python annotate.py --data-dir /path/to/dataset --video-key observation.images.base --batch-size 16
# Use a specific model
python annotate.py --repo-id user/dataset --video-key observation.images.base \\
@@ -1059,6 +1144,12 @@ def main():
choices=["bfloat16", "float16", "float32"],
help="Model dtype (default: bfloat16)",
)
parser.add_argument(
"--batch-size",
type=int,
default=8,
help="Number of episodes to process in each batch (default: 8)",
)
# Episode selection
parser.add_argument(
@@ -1116,7 +1207,8 @@ def main():
vlm = get_vlm(args.model, args.device, torch_dtype)
# Create annotator and run annotation
annotator = SkillAnnotator(vlm=vlm, console=console)
annotator = SkillAnnotator(vlm=vlm, console=console, batch_size=args.batch_size)
console.print(f"[cyan]Processing with batch size: {args.batch_size}[/cyan]")
annotations = annotator.annotate_dataset(
dataset=dataset,
video_key=args.video_key,