mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
make it work
This commit is contained in:
+249
-157
@@ -30,7 +30,6 @@ The pipeline:
|
|||||||
Supported VLMs (modular design allows easy extension):
|
Supported VLMs (modular design allows easy extension):
|
||||||
- Qwen2-VL (default): "Qwen/Qwen2-VL-7B-Instruct"
|
- Qwen2-VL (default): "Qwen/Qwen2-VL-7B-Instruct"
|
||||||
- Qwen3-VL: "Qwen/Qwen3-VL-30B-A3B-Instruct"
|
- Qwen3-VL: "Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||||
- SmolVLM: "HuggingFaceTB/SmolVLM-Instruct"
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
```bash
|
```bash
|
||||||
@@ -52,7 +51,7 @@ After running, you can access the skill for any frame via:
|
|||||||
dataset = LeRobotDataset(repo_id="your/dataset")
|
dataset = LeRobotDataset(repo_id="your/dataset")
|
||||||
item = dataset[100]
|
item = dataset[100]
|
||||||
task_idx = item["task_index"]
|
task_idx = item["task_index"]
|
||||||
skill_name = dataset.meta.tasks.iloc[task_idx].name
|
skill_name = dataset.meta.tasks.iloc[task_idx.item()].name
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -125,7 +124,7 @@ class BaseVLM(ABC):
|
|||||||
|
|
||||||
To add a new VLM:
|
To add a new VLM:
|
||||||
1. Create a subclass of BaseVLM
|
1. Create a subclass of BaseVLM
|
||||||
2. Implement the `__init__` and `segment_skills` methods
|
2. Implement the `__init__`, `segment_skills`, and `segment_skills_batch` methods
|
||||||
3. Register it in the VLM_REGISTRY dictionary
|
3. Register it in the VLM_REGISTRY dictionary
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -151,6 +150,23 @@ class BaseVLM(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def segment_skills_batch(
|
||||||
|
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
|
||||||
|
) -> list[list[Skill]]:
|
||||||
|
"""
|
||||||
|
Segment multiple videos into atomic skills in a single batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_paths: List of paths to video files
|
||||||
|
episode_durations: List of episode durations in seconds
|
||||||
|
coarse_goal: Optional high-level task description
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of skill lists, one for each video
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
|
def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
|
||||||
"""Create the prompt for skill segmentation."""
|
"""Create the prompt for skill segmentation."""
|
||||||
@@ -258,6 +274,71 @@ class Qwen2VL(BaseVLM):
|
|||||||
|
|
||||||
return self._parse_skills_response(response)
|
return self._parse_skills_response(response)
|
||||||
|
|
||||||
|
def segment_skills_batch(
|
||||||
|
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
|
||||||
|
) -> list[list[Skill]]:
|
||||||
|
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
|
||||||
|
prompt = create_skill_segmentation_prompt(coarse_goal)
|
||||||
|
|
||||||
|
# Create messages for each video
|
||||||
|
all_messages = []
|
||||||
|
for video_path, duration in zip(video_paths, episode_durations):
|
||||||
|
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "video", "video": str(video_path), "fps": 1.0},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
all_messages.append(messages)
|
||||||
|
|
||||||
|
# Process all videos in batch
|
||||||
|
all_texts = []
|
||||||
|
all_image_inputs = []
|
||||||
|
all_video_inputs = []
|
||||||
|
|
||||||
|
for messages in all_messages:
|
||||||
|
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||||
|
all_texts.append(text)
|
||||||
|
all_image_inputs.extend(image_inputs or [])
|
||||||
|
all_video_inputs.extend(video_inputs or [])
|
||||||
|
|
||||||
|
inputs = self.processor(
|
||||||
|
text=all_texts,
|
||||||
|
images=all_image_inputs if all_image_inputs else None,
|
||||||
|
videos=all_video_inputs if all_video_inputs else None,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
||||||
|
|
||||||
|
responses = self.processor.batch_decode(
|
||||||
|
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||||
|
skip_special_tokens=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse each response
|
||||||
|
all_skills = []
|
||||||
|
for response in responses:
|
||||||
|
try:
|
||||||
|
skills = self._parse_skills_response(response.strip())
|
||||||
|
all_skills.append(skills)
|
||||||
|
except Exception as e:
|
||||||
|
self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
|
||||||
|
all_skills.append([])
|
||||||
|
|
||||||
|
return all_skills
|
||||||
|
|
||||||
def _parse_skills_response(self, response: str) -> list[Skill]:
|
def _parse_skills_response(self, response: str) -> list[Skill]:
|
||||||
"""Parse the VLM response into Skill objects."""
|
"""Parse the VLM response into Skill objects."""
|
||||||
# Extract JSON from response
|
# Extract JSON from response
|
||||||
@@ -349,6 +430,71 @@ class Qwen3VL(BaseVLM):
|
|||||||
|
|
||||||
return self._parse_skills_response(response)
|
return self._parse_skills_response(response)
|
||||||
|
|
||||||
|
def segment_skills_batch(
|
||||||
|
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
|
||||||
|
) -> list[list[Skill]]:
|
||||||
|
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
|
||||||
|
prompt = create_skill_segmentation_prompt(coarse_goal)
|
||||||
|
|
||||||
|
# Create messages for each video
|
||||||
|
all_messages = []
|
||||||
|
for video_path, duration in zip(video_paths, episode_durations):
|
||||||
|
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "video", "video": str(video_path), "fps": 1.0},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
all_messages.append(messages)
|
||||||
|
|
||||||
|
# Process all videos in batch
|
||||||
|
all_texts = []
|
||||||
|
all_image_inputs = []
|
||||||
|
all_video_inputs = []
|
||||||
|
|
||||||
|
for messages in all_messages:
|
||||||
|
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||||
|
all_texts.append(text)
|
||||||
|
all_image_inputs.extend(image_inputs or [])
|
||||||
|
all_video_inputs.extend(video_inputs or [])
|
||||||
|
|
||||||
|
inputs = self.processor(
|
||||||
|
text=all_texts,
|
||||||
|
images=all_image_inputs if all_image_inputs else None,
|
||||||
|
videos=all_video_inputs if all_video_inputs else None,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
||||||
|
|
||||||
|
responses = self.processor.batch_decode(
|
||||||
|
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||||
|
skip_special_tokens=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse each response
|
||||||
|
all_skills = []
|
||||||
|
for response in responses:
|
||||||
|
try:
|
||||||
|
skills = self._parse_skills_response(response.strip())
|
||||||
|
all_skills.append(skills)
|
||||||
|
except Exception as e:
|
||||||
|
self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
|
||||||
|
all_skills.append([])
|
||||||
|
|
||||||
|
return all_skills
|
||||||
|
|
||||||
def _parse_skills_response(self, response: str) -> list[Skill]:
|
def _parse_skills_response(self, response: str) -> list[Skill]:
|
||||||
"""Parse the VLM response into Skill objects."""
|
"""Parse the VLM response into Skill objects."""
|
||||||
if "```json" in response:
|
if "```json" in response:
|
||||||
@@ -371,137 +517,6 @@ class Qwen3VL(BaseVLM):
|
|||||||
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
|
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# SmolVLM Implementation
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class SmolVLM(BaseVLM):
|
|
||||||
"""SmolVLM model for skill segmentation (lighter weight alternative)."""
|
|
||||||
|
|
||||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
|
||||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
|
||||||
|
|
||||||
self.console = Console()
|
|
||||||
self.device = device
|
|
||||||
self.model_name = model_name
|
|
||||||
|
|
||||||
self.console.print(f"[cyan]Loading SmolVLM model: {model_name}...[/cyan]")
|
|
||||||
|
|
||||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
|
||||||
self.model = AutoModelForVision2Seq.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
# _attn_implementation="flash_attention_2" if device == "cuda" else "eager",
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
|
||||||
|
|
||||||
def segment_skills(
|
|
||||||
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
|
|
||||||
) -> list[Skill]:
|
|
||||||
"""Segment video into skills using SmolVLM with frame sampling."""
|
|
||||||
import PIL.Image
|
|
||||||
|
|
||||||
# SmolVLM works with images, so we sample frames from the video
|
|
||||||
frames = self._extract_frames(video_path, target_fps=1)
|
|
||||||
|
|
||||||
if not frames:
|
|
||||||
raise ValueError(f"Could not extract frames from {video_path}")
|
|
||||||
|
|
||||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
|
||||||
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
|
||||||
|
|
||||||
# Sample frames (up to 8 frames to avoid context overflow)
|
|
||||||
frame_indices = self._select_frame_indices(len(frames), max_frames=8)
|
|
||||||
|
|
||||||
# Convert frames to PIL images
|
|
||||||
pil_images = [
|
|
||||||
PIL.Image.fromarray(cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB))
|
|
||||||
for idx in frame_indices
|
|
||||||
]
|
|
||||||
|
|
||||||
# Create message content with image placeholders
|
|
||||||
content = [{"type": "text", "text": prompt}]
|
|
||||||
|
|
||||||
# Add image placeholders (one for each frame)
|
|
||||||
for _ in frame_indices:
|
|
||||||
content.append({"type": "image"})
|
|
||||||
|
|
||||||
content.append(
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": f"These are {len(frame_indices)} sampled frames from a {duration_str} video. Segment into atomic skills.",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": content}]
|
|
||||||
|
|
||||||
# Apply chat template to get the prompt
|
|
||||||
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
|
|
||||||
|
|
||||||
# Process inputs with both text and images
|
|
||||||
inputs = self.processor(text=prompt, images=pil_images, return_tensors="pt")
|
|
||||||
inputs = inputs.to(self.device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
|
||||||
|
|
||||||
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
|
||||||
|
|
||||||
return self._parse_skills_response(response, episode_duration)
|
|
||||||
|
|
||||||
def _extract_frames(self, video_path: Path, target_fps: int = 1) -> list:
|
|
||||||
"""Extract frames from video at target FPS."""
|
|
||||||
cap = cv2.VideoCapture(str(video_path))
|
|
||||||
frames = []
|
|
||||||
fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
|
||||||
frame_interval = int(fps / target_fps)
|
|
||||||
|
|
||||||
frame_count = 0
|
|
||||||
while True:
|
|
||||||
ret, frame = cap.read()
|
|
||||||
if not ret:
|
|
||||||
break
|
|
||||||
if frame_count % frame_interval == 0:
|
|
||||||
frames.append(frame)
|
|
||||||
frame_count += 1
|
|
||||||
|
|
||||||
cap.release()
|
|
||||||
return frames
|
|
||||||
|
|
||||||
def _select_frame_indices(self, total_frames: int, max_frames: int = 8) -> list[int]:
|
|
||||||
"""Select evenly spaced frame indices."""
|
|
||||||
if total_frames <= max_frames:
|
|
||||||
return list(range(total_frames))
|
|
||||||
step = total_frames / max_frames
|
|
||||||
return [int(i * step) for i in range(max_frames)]
|
|
||||||
|
|
||||||
def _parse_skills_response(self, response: str, episode_duration: float) -> list[Skill]:
|
|
||||||
"""Parse the VLM response into Skill objects."""
|
|
||||||
if "```json" in response:
|
|
||||||
response = response.split("```json")[1].split("```")[0]
|
|
||||||
elif "```" in response:
|
|
||||||
response = response.split("```")[1].split("```")[0]
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = json.loads(response)
|
|
||||||
skills_data = data.get("skills", data)
|
|
||||||
breakpoint()
|
|
||||||
if isinstance(skills_data, list):
|
|
||||||
return [Skill.from_dict(s) for s in skills_data]
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
match = re.search(r"\{.*\}", response, re.DOTALL)
|
|
||||||
if match:
|
|
||||||
data = json.loads(match.group())
|
|
||||||
skills_data = data.get("skills", [])
|
|
||||||
return [Skill.from_dict(s) for s in skills_data]
|
|
||||||
|
|
||||||
# Fallback: create a single skill covering the whole episode
|
|
||||||
self.console.print("[yellow]Warning: Could not parse skills, creating single skill[/yellow]")
|
|
||||||
return [Skill(name="perform manipulation", start=0.0, end=episode_duration)]
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# VLM Registry - Add new VLMs here
|
# VLM Registry - Add new VLMs here
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -513,10 +528,6 @@ VLM_REGISTRY: dict[str, type[BaseVLM]] = {
|
|||||||
"Qwen/Qwen2-VL-72B-Instruct": Qwen2VL,
|
"Qwen/Qwen2-VL-72B-Instruct": Qwen2VL,
|
||||||
# Qwen3-VL variants (MoE)
|
# Qwen3-VL variants (MoE)
|
||||||
"Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
|
"Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
|
||||||
# SmolVLM variants
|
|
||||||
"HuggingFaceTB/SmolVLM-Instruct": SmolVLM,
|
|
||||||
"HuggingFaceTB/SmolVLM-256M-Instruct": SmolVLM,
|
|
||||||
"HuggingFaceTB/SmolVLM-500M-Instruct": SmolVLM,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -545,8 +556,6 @@ def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = to
|
|||||||
return Qwen3VL(model_name, device, torch_dtype)
|
return Qwen3VL(model_name, device, torch_dtype)
|
||||||
elif "qwen2" in model_lower or "qwen-vl" in model_lower:
|
elif "qwen2" in model_lower or "qwen-vl" in model_lower:
|
||||||
return Qwen2VL(model_name, device, torch_dtype)
|
return Qwen2VL(model_name, device, torch_dtype)
|
||||||
elif "smolvlm" in model_lower:
|
|
||||||
return SmolVLM(model_name, device, torch_dtype)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown model: {model_name}. "
|
f"Unknown model: {model_name}. "
|
||||||
@@ -660,10 +669,12 @@ class SkillAnnotator:
|
|||||||
vlm: BaseVLM,
|
vlm: BaseVLM,
|
||||||
video_extractor: VideoExtractor | None = None,
|
video_extractor: VideoExtractor | None = None,
|
||||||
console: Console | None = None,
|
console: Console | None = None,
|
||||||
|
batch_size: int = 8,
|
||||||
):
|
):
|
||||||
self.vlm = vlm
|
self.vlm = vlm
|
||||||
self.console = console or Console()
|
self.console = console or Console()
|
||||||
self.video_extractor = video_extractor or VideoExtractor(self.console)
|
self.video_extractor = video_extractor or VideoExtractor(self.console)
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
def annotate_dataset(
|
def annotate_dataset(
|
||||||
self,
|
self,
|
||||||
@@ -673,7 +684,7 @@ class SkillAnnotator:
|
|||||||
skip_existing: bool = False,
|
skip_existing: bool = False,
|
||||||
) -> dict[int, EpisodeSkills]:
|
) -> dict[int, EpisodeSkills]:
|
||||||
"""
|
"""
|
||||||
Annotate all episodes in a dataset with skill labels.
|
Annotate all episodes in a dataset with skill labels using batched processing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: LeRobot dataset to annotate
|
dataset: LeRobot dataset to annotate
|
||||||
@@ -690,18 +701,33 @@ class SkillAnnotator:
|
|||||||
# Get coarse task description if available
|
# Get coarse task description if available
|
||||||
coarse_goal = self._get_coarse_goal(dataset)
|
coarse_goal = self._get_coarse_goal(dataset)
|
||||||
|
|
||||||
# with Progress(
|
print(f"Annotating {len(episode_indices)} episodes in batches of {self.batch_size}...")
|
||||||
# SpinnerColumn(),
|
|
||||||
# TextColumn("[progress.description]{task.description}"),
|
|
||||||
# console=self.console,
|
|
||||||
# ) as progress:
|
|
||||||
# task = progress.add_task(f"Annotating {len(episode_indices)} episodes...", total=len(episode_indices))
|
|
||||||
print(f"Annotating {len(episode_indices)} episodes...")
|
|
||||||
|
|
||||||
for ep_idx in episode_indices:
|
# Process episodes in batches
|
||||||
# progress.update(task, description=f"Processing episode {ep_idx}...")
|
for batch_start in range(0, len(episode_indices), self.batch_size):
|
||||||
print(f"Processing episode {ep_idx}...")
|
batch_end = min(batch_start + self.batch_size, len(episode_indices))
|
||||||
|
batch_episodes = episode_indices[batch_start:batch_end]
|
||||||
|
|
||||||
|
print(f"Processing batch {batch_start//self.batch_size + 1}/{(len(episode_indices) + self.batch_size - 1)//self.batch_size} (episodes {batch_episodes[0]} to {batch_episodes[-1]})...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
batch_annotations = self._annotate_episodes_batch(
|
||||||
|
dataset, batch_episodes, video_key, coarse_goal
|
||||||
|
)
|
||||||
|
|
||||||
|
for ep_idx, skills in batch_annotations.items():
|
||||||
|
annotations[ep_idx] = EpisodeSkills(
|
||||||
|
episode_index=ep_idx,
|
||||||
|
description=coarse_goal,
|
||||||
|
skills=skills,
|
||||||
|
)
|
||||||
|
self.console.print(
|
||||||
|
f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.console.print(f"[red]✗ Batch failed: {e}. Falling back to single-episode processing...[/red]")
|
||||||
|
# Fallback: process episodes one by one
|
||||||
|
for ep_idx in batch_episodes:
|
||||||
try:
|
try:
|
||||||
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
|
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
|
||||||
annotations[ep_idx] = EpisodeSkills(
|
annotations[ep_idx] = EpisodeSkills(
|
||||||
@@ -715,8 +741,6 @@ class SkillAnnotator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")
|
self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")
|
||||||
|
|
||||||
# progress.advance(task)
|
|
||||||
|
|
||||||
return annotations
|
return annotations
|
||||||
|
|
||||||
def _get_coarse_goal(self, dataset: LeRobotDataset) -> str:
|
def _get_coarse_goal(self, dataset: LeRobotDataset) -> str:
|
||||||
@@ -730,6 +754,67 @@ class SkillAnnotator:
|
|||||||
|
|
||||||
return "Perform the demonstrated manipulation task."
|
return "Perform the demonstrated manipulation task."
|
||||||
|
|
||||||
|
def _annotate_episodes_batch(
|
||||||
|
self,
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
episode_indices: list[int],
|
||||||
|
video_key: str,
|
||||||
|
coarse_goal: str,
|
||||||
|
) -> dict[int, list[Skill]]:
|
||||||
|
"""Annotate multiple episodes with skill labels in a batch."""
|
||||||
|
# Extract all videos for this batch
|
||||||
|
extracted_paths = []
|
||||||
|
durations = []
|
||||||
|
valid_episode_indices = []
|
||||||
|
|
||||||
|
for ep_idx in episode_indices:
|
||||||
|
try:
|
||||||
|
# Get video path and timestamps
|
||||||
|
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
|
||||||
|
|
||||||
|
if not video_path.exists():
|
||||||
|
self.console.print(f"[yellow]Warning: Video not found for episode {ep_idx}[/yellow]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get episode timestamps from metadata
|
||||||
|
ep = dataset.meta.episodes[ep_idx]
|
||||||
|
start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
|
||||||
|
end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
|
||||||
|
duration = end_ts - start_ts
|
||||||
|
|
||||||
|
# Extract episode segment to temporary file
|
||||||
|
extracted_path = self.video_extractor.extract_episode_video(
|
||||||
|
video_path, start_ts, end_ts, target_fps=1
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted_paths.append(extracted_path)
|
||||||
|
durations.append(duration)
|
||||||
|
valid_episode_indices.append(ep_idx)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.console.print(f"[yellow]Warning: Failed to extract video for episode {ep_idx}: {e}[/yellow]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not extracted_paths:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run VLM skill segmentation in batch
|
||||||
|
all_skills = self.vlm.segment_skills_batch(extracted_paths, durations, coarse_goal)
|
||||||
|
|
||||||
|
# Map results back to episode indices
|
||||||
|
results = {}
|
||||||
|
for ep_idx, skills in zip(valid_episode_indices, all_skills):
|
||||||
|
results[ep_idx] = skills
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up all temporary files
|
||||||
|
for path in extracted_paths:
|
||||||
|
if path.exists():
|
||||||
|
path.unlink()
|
||||||
|
|
||||||
def _annotate_episode(
|
def _annotate_episode(
|
||||||
self,
|
self,
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
@@ -1007,15 +1092,15 @@ def load_skill_annotations(dataset_root: Path) -> dict | None:
|
|||||||
def main():
|
def main():
|
||||||
"""Main entry point for the skill annotation script."""
|
"""Main entry point for the skill annotation script."""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Automatic skill annotation for LeRobot datasets using VLMs",
|
description="Automatic skill annotation for LeRobot datasets using VLMs (with batched processing)",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog=textwrap.dedent("""\
|
epilog=textwrap.dedent("""\
|
||||||
Examples:
|
Examples:
|
||||||
# Annotate a HuggingFace Hub dataset
|
# Annotate a HuggingFace Hub dataset
|
||||||
python annotate.py --repo-id user/dataset --video-key observation.images.base
|
python annotate.py --repo-id user/dataset --video-key observation.images.base
|
||||||
|
|
||||||
# Annotate a local dataset
|
# Annotate a local dataset with custom batch size
|
||||||
python annotate.py --data-dir /path/to/dataset --video-key observation.images.base
|
python annotate.py --data-dir /path/to/dataset --video-key observation.images.base --batch-size 16
|
||||||
|
|
||||||
# Use a specific model
|
# Use a specific model
|
||||||
python annotate.py --repo-id user/dataset --video-key observation.images.base \\
|
python annotate.py --repo-id user/dataset --video-key observation.images.base \\
|
||||||
@@ -1059,6 +1144,12 @@ def main():
|
|||||||
choices=["bfloat16", "float16", "float32"],
|
choices=["bfloat16", "float16", "float32"],
|
||||||
help="Model dtype (default: bfloat16)",
|
help="Model dtype (default: bfloat16)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of episodes to process in each batch (default: 8)",
|
||||||
|
)
|
||||||
|
|
||||||
# Episode selection
|
# Episode selection
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -1116,7 +1207,8 @@ def main():
|
|||||||
vlm = get_vlm(args.model, args.device, torch_dtype)
|
vlm = get_vlm(args.model, args.device, torch_dtype)
|
||||||
|
|
||||||
# Create annotator and run annotation
|
# Create annotator and run annotation
|
||||||
annotator = SkillAnnotator(vlm=vlm, console=console)
|
annotator = SkillAnnotator(vlm=vlm, console=console, batch_size=args.batch_size)
|
||||||
|
console.print(f"[cyan]Processing with batch size: {args.batch_size}[/cyan]")
|
||||||
annotations = annotator.annotate_dataset(
|
annotations = annotator.annotate_dataset(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
video_key=args.video_key,
|
video_key=args.video_key,
|
||||||
|
|||||||
Reference in New Issue
Block a user