Files
lerobot/examples/dataset/annotate.py
T
Jade Choghari 9091b68d86 make it work
2025-12-08 14:19:15 +00:00

1245 lines
44 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Automatic Skill Annotation for LeRobot Datasets.
This script performs automatic subtask/skill labeling for ANY LeRobot dataset using
Vision-Language Models (VLMs). It segments each robot demonstration into short atomic
skills (1-3 seconds each) and updates the dataset's task field.
The pipeline:
1. Loads a LeRobot dataset (local or from HuggingFace Hub)
2. For each episode, cuts its segment out of the selected video stream into a temporary 1 fps clip
3. Uses a VLM to identify skill boundaries and labels
4. Updates the dataset's task metadata with skill annotations
Supported VLMs (modular design allows easy extension):
- Qwen2-VL (default): "Qwen/Qwen2-VL-7B-Instruct"
- Qwen3-VL: "Qwen/Qwen3-VL-30B-A3B-Instruct"
Usage:
```bash
python examples/dataset/annotate.py \
--repo-id your-username/your-dataset \
--video-key observation.images.base \
--model Qwen/Qwen2-VL-7B-Instruct \
--push-to-hub
```
Or with a local dataset:
```bash
python examples/dataset/annotate.py \
--data-dir /path/to/local/dataset \
--video-key observation.images.base
```
After running, you can access the skill for any frame via:
```python
dataset = LeRobotDataset(repo_id="your/dataset")
item = dataset[100]
task_idx = item["task_index"]
skill_name = dataset.meta.tasks.iloc[task_idx.item()].name
```
"""
import argparse
import json
import re
import subprocess
import tempfile
import textwrap
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import cv2
import torch
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# =============================================================================
# Skill Annotation Data Structures
# =============================================================================
class Skill:
    """Represents a single atomic skill/subtask in a demonstration.

    ``start`` and ``end`` are timestamps in seconds. The VLM reports them relative to the
    extracted episode clip it was shown.
    """

    def __init__(self, name: str, start: float, end: float):
        self.name = name
        self.start = start  # Start timestamp in seconds
        self.end = end  # End timestamp in seconds

    def to_dict(self) -> dict:
        return {"name": self.name, "start": self.start, "end": self.end}

    @classmethod
    def from_dict(cls, data: dict) -> "Skill":
        """Build a Skill from a mapping with "name", "start" and "end" keys.

        Values are coerced (name -> str, start/end -> float) because model output may encode
        numbers as strings. Without coercion, ``__repr__`` formatting and the timestamp
        comparisons done on skills would fail.

        Raises:
            KeyError: if a required key is missing.
            ValueError / TypeError: if start or end cannot be converted to float.
        """
        return cls(name=str(data["name"]), start=float(data["start"]), end=float(data["end"]))

    def __repr__(self) -> str:
        return f"Skill(name='{self.name}', start={self.start:.2f}, end={self.end:.2f})"
class EpisodeSkills:
    """All skill annotations for one episode, together with its episode-level task text."""

    def __init__(self, episode_index: int, description: str, skills: list[Skill]):
        self.episode_index = episode_index  # which episode this annotation belongs to
        self.description = description  # coarse task description for the episode
        self.skills = skills  # skills in the order they were given

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict; each skill is serialized via its own ``to_dict``."""
        serialized_skills = list(map(lambda skill: skill.to_dict(), self.skills))
        return dict(
            episode_index=self.episode_index,
            description=self.description,
            skills=serialized_skills,
        )
# =============================================================================
# VLM Interface (Abstract Base Class for Modularity)
# =============================================================================
class BaseVLM(ABC):
    """
    Interface that every vision-language backend must implement.

    Adding a backend:
      1. subclass ``BaseVLM``;
      2. implement ``__init__``, ``segment_skills`` and ``segment_skills_batch``;
      3. register the model identifier(s) in ``VLM_REGISTRY``.
    """

    @abstractmethod
    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        """Load ``model_name`` on ``device`` with weights in ``torch_dtype``."""
        ...

    @abstractmethod
    def segment_skills(
        self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
    ) -> list[Skill]:
        """
        Split one video into atomic skills.

        Args:
            video_path: location of the video file to analyse
            episode_duration: length of the episode, in seconds
            coarse_goal: optional high-level description of the task

        Returns:
            The Skill objects found in the video.
        """
        ...

    @abstractmethod
    def segment_skills_batch(
        self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
    ) -> list[list[Skill]]:
        """
        Split several videos into atomic skills with one batched call.

        Args:
            video_paths: locations of the video files, one per episode
            episode_durations: episode lengths in seconds, aligned with ``video_paths``
            coarse_goal: optional high-level description of the task

        Returns:
            One list of Skill objects per input video, in input order.
        """
        ...
def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
"""Create the prompt for skill segmentation."""
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
return textwrap.dedent(f"""\
# Role
You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations.
# Task
{goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should:
- Last approximately 1-3 seconds
- Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper")
- Have precise start and end timestamps
# Requirements
1. **Atomic Actions**: Each skill should be a single, indivisible action
2. **Complete Coverage**: Skills must cover the entire video duration with no gaps
3. **Boundary Consistency**: The end of one skill equals the start of the next
4. **Natural Language**: Use clear, descriptive names for each skill
5. **Timestamps**: Use seconds (float) for all timestamps
# Output Format
After your analysis, output ONLY valid JSON with this exact structure:
```json
{{
"skills": [
{{"name": "skill description", "start": 0.0, "end": 1.5}},
{{"name": "another skill", "start": 1.5, "end": 3.2}}
]
}}
```
The first skill must start at 0.0 and the last skill must end at the video duration.
""")
# =============================================================================
# Qwen2-VL Implementation
# =============================================================================
class Qwen2VL(BaseVLM):
    """Qwen2-VL model for skill segmentation.

    Each video is passed to the model at 1 fps together with the segmentation prompt. The
    model's JSON reply is parsed into Skill objects.
    """

    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        # Imported here rather than at module level (presumably to keep these dependencies optional).
        from qwen_vl_utils import process_vision_info
        from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        self.console = Console()
        self.device = device
        self.model_name = model_name
        self.process_vision_info = process_vision_info
        self.console.print(f"[cyan]Loading Qwen2-VL model: {model_name}...[/cyan]")
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        # Batched generation with a decoder-only model needs left padding, so that generation
        # starts right after every prompt instead of after trailing pad tokens.
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None:
            tokenizer.padding_side = "left"
        self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")

    @staticmethod
    def _build_messages(video_path: Path, duration: float, prompt: str) -> list[dict]:
        """Build the chat messages for one video: system prompt, the video, and a duration hint."""
        duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
        return [
            {"role": "system", "content": [{"type": "text", "text": prompt}]},
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path), "fps": 1.0},
                    {
                        "type": "text",
                        "text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
                    },
                ],
            },
        ]

    def _generate(self, all_messages: list[list[dict]]) -> list[str]:
        """Run a single ``generate`` call over a batch of conversations and return the stripped replies."""
        texts = []
        images = []
        videos = []
        for messages in all_messages:
            texts.append(self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
            image_inputs, video_inputs = self.process_vision_info(messages)
            images.extend(image_inputs or [])
            videos.extend(video_inputs or [])
        inputs = self.processor(
            text=texts,
            images=images or None,
            videos=videos or None,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
        # Drop the prompt tokens so that only newly generated text is decoded.
        responses = self.processor.batch_decode(
            [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
            skip_special_tokens=True,
        )
        return [response.strip() for response in responses]

    def segment_skills(
        self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
    ) -> list[Skill]:
        """Segment video into skills using Qwen2-VL.

        Raises:
            ValueError: if the model reply cannot be parsed into skills.
        """
        prompt = create_skill_segmentation_prompt(coarse_goal)
        response = self._generate([self._build_messages(video_path, episode_duration, prompt)])[0]
        return self._parse_skills_response(response)

    def segment_skills_batch(
        self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
    ) -> list[list[Skill]]:
        """Segment multiple videos into skills using Qwen2-VL in a batch.

        A reply that cannot be parsed is logged and yields an empty skill list for that video.
        """
        if not video_paths:
            return []
        prompt = create_skill_segmentation_prompt(coarse_goal)
        batch = [
            self._build_messages(video_path, duration, prompt)
            for video_path, duration in zip(video_paths, episode_durations)
        ]
        all_skills = []
        for response in self._generate(batch):
            try:
                all_skills.append(self._parse_skills_response(response))
            except Exception as e:  # best effort: one bad reply must not sink the batch
                self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
                all_skills.append([])
        return all_skills

    def _parse_skills_response(self, response: str) -> list[Skill]:
        """Parse the VLM response into Skill objects.

        Accepts a fenced json code block, a JSON object holding a "skills" list, a bare JSON list
        of skills, or a JSON object embedded in surrounding prose.

        Raises:
            ValueError: if no list of skills can be recovered from the response.
        """
        text = response
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0]
        elif "```" in text:
            text = text.split("```")[1].split("```")[0]
        error_message = f"Could not parse skills from response: {text[:200]}..."
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} span in case the JSON is wrapped in prose.
            match = re.search(r"\{.*\}", text, re.DOTALL)
            if match is None:
                raise ValueError(error_message) from None
            try:
                data = json.loads(match.group())
            except json.JSONDecodeError as err:
                raise ValueError(error_message) from err
        # A top-level list is accepted as the skills themselves; dicts must carry the list.
        skills_data = data.get("skills", data) if isinstance(data, dict) else data
        if isinstance(skills_data, list):
            return [Skill.from_dict(s) for s in skills_data]
        raise ValueError(error_message)
# =============================================================================
# Qwen3-VL Implementation (MoE variant)
# =============================================================================
class Qwen3VL(BaseVLM):
    """Qwen3-VL MoE model for skill segmentation.

    Each video is passed to the model at 1 fps together with the segmentation prompt. The
    model's JSON reply is parsed into Skill objects.
    """

    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        # Imported here rather than at module level (presumably to keep these dependencies optional).
        from qwen_vl_utils import process_vision_info
        from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration

        self.console = Console()
        self.device = device
        self.model_name = model_name
        self.process_vision_info = process_vision_info
        self.console.print(f"[cyan]Loading Qwen3-VL model: {model_name}...[/cyan]")
        self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        # Batched generation with a decoder-only model needs left padding, so that generation
        # starts right after every prompt instead of after trailing pad tokens.
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None:
            tokenizer.padding_side = "left"
        self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")

    @staticmethod
    def _build_messages(video_path: Path, duration: float, prompt: str) -> list[dict]:
        """Build the chat messages for one video: system prompt, the video, and a duration hint."""
        duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
        return [
            {"role": "system", "content": [{"type": "text", "text": prompt}]},
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path), "fps": 1.0},
                    {
                        "type": "text",
                        "text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
                    },
                ],
            },
        ]

    def _generate(self, all_messages: list[list[dict]]) -> list[str]:
        """Run a single ``generate`` call over a batch of conversations and return the stripped replies."""
        texts = []
        images = []
        videos = []
        for messages in all_messages:
            texts.append(self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
            image_inputs, video_inputs = self.process_vision_info(messages)
            images.extend(image_inputs or [])
            videos.extend(video_inputs or [])
        inputs = self.processor(
            text=texts,
            images=images or None,
            videos=videos or None,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
        # Drop the prompt tokens so that only newly generated text is decoded.
        responses = self.processor.batch_decode(
            [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
            skip_special_tokens=True,
        )
        return [response.strip() for response in responses]

    def segment_skills(
        self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
    ) -> list[Skill]:
        """Segment video into skills using Qwen3-VL.

        Raises:
            ValueError: if the model reply cannot be parsed into skills.
        """
        prompt = create_skill_segmentation_prompt(coarse_goal)
        response = self._generate([self._build_messages(video_path, episode_duration, prompt)])[0]
        return self._parse_skills_response(response)

    def segment_skills_batch(
        self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
    ) -> list[list[Skill]]:
        """Segment multiple videos into skills using Qwen3-VL in a batch.

        A reply that cannot be parsed is logged and yields an empty skill list for that video.
        """
        if not video_paths:
            return []
        prompt = create_skill_segmentation_prompt(coarse_goal)
        batch = [
            self._build_messages(video_path, duration, prompt)
            for video_path, duration in zip(video_paths, episode_durations)
        ]
        all_skills = []
        for response in self._generate(batch):
            try:
                all_skills.append(self._parse_skills_response(response))
            except Exception as e:  # best effort: one bad reply must not sink the batch
                self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
                all_skills.append([])
        return all_skills

    def _parse_skills_response(self, response: str) -> list[Skill]:
        """Parse the VLM response into Skill objects.

        Accepts a fenced json code block, a JSON object holding a "skills" list, a bare JSON list
        of skills, or a JSON object embedded in surrounding prose.

        Raises:
            ValueError: if no list of skills can be recovered from the response.
        """
        text = response
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0]
        elif "```" in text:
            text = text.split("```")[1].split("```")[0]
        error_message = f"Could not parse skills from response: {text[:200]}..."
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} span in case the JSON is wrapped in prose.
            match = re.search(r"\{.*\}", text, re.DOTALL)
            if match is None:
                raise ValueError(error_message) from None
            try:
                data = json.loads(match.group())
            except json.JSONDecodeError as err:
                raise ValueError(error_message) from err
        # A top-level list is accepted as the skills themselves; dicts must carry the list.
        skills_data = data.get("skills", data) if isinstance(data, dict) else data
        if isinstance(skills_data, list):
            return [Skill.from_dict(s) for s in skills_data]
        raise ValueError(error_message)
# =============================================================================
# VLM Registry - Add new VLMs here
# =============================================================================
# Maps exact HuggingFace model identifiers to the BaseVLM subclass that loads them.
VLM_REGISTRY: dict[str, type[BaseVLM]] = {
    # Qwen2-VL variants share one implementation.
    **dict.fromkeys(
        (
            "Qwen/Qwen2-VL-2B-Instruct",
            "Qwen/Qwen2-VL-7B-Instruct",
            "Qwen/Qwen2-VL-72B-Instruct",
        ),
        Qwen2VL,
    ),
    # Qwen3-VL variants (MoE)
    "Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
}
def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16) -> BaseVLM:
    """
    Factory function to get the appropriate VLM based on model name.

    Exact registry matches win. Otherwise the (lower-cased) name is matched against known
    model families, so fine-tuned checkpoints and local paths can still resolve.

    Args:
        model_name: HuggingFace model identifier
        device: Device to load model on
        torch_dtype: Data type for model weights

    Returns:
        Initialized VLM instance

    Raises:
        ValueError: If the model is not in the registry and cannot be matched, or if it is a
            Qwen2.5-VL checkpoint (not handled by the Qwen2-VL backend).
    """
    # Check exact match first
    if model_name in VLM_REGISTRY:
        return VLM_REGISTRY[model_name](model_name, device, torch_dtype)

    model_lower = model_name.lower()
    # Qwen2.5-VL names also contain "qwen2", but those checkpoints use a different architecture
    # (transformers' Qwen2_5_VLForConditionalGeneration). Routing them to Qwen2VL would load them
    # with the wrong model class, so fail early with an explicit message instead.
    if "qwen2.5" in model_lower or "qwen2_5" in model_lower:
        raise ValueError(
            f"Unsupported model: {model_name}. "
            "Qwen2.5-VL checkpoints are not handled by the Qwen2-VL backend; "
            "implement a new VLM class inheriting from BaseVLM for them. "
            f"Supported models: {list(VLM_REGISTRY.keys())}."
        )
    # Check for partial matches (e.g., "qwen2" in model name)
    if "qwen3" in model_lower:
        return Qwen3VL(model_name, device, torch_dtype)
    if "qwen2" in model_lower or "qwen-vl" in model_lower:
        return Qwen2VL(model_name, device, torch_dtype)

    raise ValueError(
        f"Unknown model: {model_name}. "
        f"Supported models: {list(VLM_REGISTRY.keys())}. "
        "Or implement a new VLM class inheriting from BaseVLM."
    )
# =============================================================================
# Video Extraction Utilities
# =============================================================================
class VideoExtractor:
    """Utilities for extracting and processing video segments from LeRobot datasets."""

    def __init__(self, console: Console | None = None):
        self.console = console or Console()

    def extract_episode_video(
        self,
        video_path: Path,
        start_timestamp: float,
        end_timestamp: float,
        target_fps: int = 1,
    ) -> Path:
        """
        Extract a specific episode segment from a concatenated video file.

        The segment is re-encoded with libx264 at ``target_fps`` into a temporary ``.mp4``.
        The caller owns the returned file and is responsible for deleting it. On failure the
        temporary file is removed before raising.

        Args:
            video_path: Path to the source video file
            start_timestamp: Start time in seconds
            end_timestamp: End time in seconds
            target_fps: Target frames per second for output

        Returns:
            Path to the extracted temporary video file

        Raises:
            RuntimeError: if ffmpeg is missing, fails, or produces a (near-)empty file.
        """
        duration = end_timestamp - start_timestamp
        self.console.print(
            f"[cyan]Extracting: {start_timestamp:.1f}s - {end_timestamp:.1f}s ({duration:.1f}s)[/cyan]"
        )

        tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        tmp_path = Path(tmp_file.name)
        tmp_file.close()

        cmd = [
            "ffmpeg",
            "-loglevel",
            "error",
            # -ss before -i seeks in the input instead of decoding everything up to the start
            # point. Because the stream is re-encoded, ffmpeg still trims to the exact timestamp.
            "-ss",
            str(start_timestamp),
            "-i",
            str(video_path),
            "-t",
            str(duration),
            "-r",
            str(target_fps),
            "-c:v",
            "libx264",
            "-preset",
            "ultrafast",
            "-crf",
            "23",
            "-an",
            "-y",
            str(tmp_path),
        ]
        try:
            subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, check=True)
        except subprocess.CalledProcessError as e:
            tmp_path.unlink(missing_ok=True)
            stderr_tail = (e.stderr or b"").decode(errors="replace").strip()[-500:]
            raise RuntimeError(f"FFmpeg failed: {e}\n{stderr_tail}") from e
        except FileNotFoundError as e:
            tmp_path.unlink(missing_ok=True)
            raise RuntimeError("FFmpeg not found. Please install ffmpeg.") from e

        # A file under 1 KiB is treated as a failed encode even if ffmpeg exited with 0.
        if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
            tmp_path.unlink(missing_ok=True)
            raise RuntimeError("Video extraction produced invalid file")
        return tmp_path

    def get_video_duration(self, video_path: Path) -> float:
        """Get duration of a video file in seconds (frame count / fps, assuming 30 fps if unknown)."""
        cap = cv2.VideoCapture(str(video_path))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        return frame_count / fps
# =============================================================================
# Skill Annotation Pipeline
# =============================================================================
class SkillAnnotator:
    """
    Main class for annotating LeRobot datasets with skill labels.

    Pipeline:
    1. Select the episodes to annotate (optionally reusing ones already in meta/skills.json)
    2. Cut each episode out of its video file into a temporary 1 fps clip
    3. Run VLM-based skill segmentation (batched, with a per-episode fallback)
    4. Return per-episode annotations; persisting them is left to the caller
    """

    def __init__(
        self,
        vlm: BaseVLM,
        video_extractor: VideoExtractor | None = None,
        console: Console | None = None,
        batch_size: int = 8,
    ):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.vlm = vlm
        self.console = console or Console()
        self.video_extractor = video_extractor or VideoExtractor(self.console)
        self.batch_size = batch_size

    def annotate_dataset(
        self,
        dataset: LeRobotDataset,
        video_key: str,
        episodes: list[int] | None = None,
        skip_existing: bool = False,
    ) -> dict[int, EpisodeSkills]:
        """
        Annotate episodes in a dataset with skill labels using batched processing.

        Args:
            dataset: LeRobot dataset to annotate
            video_key: Key for video observations (e.g., "observation.images.base")
            episodes: Specific episode indices to annotate (None = all; an empty list = none)
            skip_existing: Reuse the annotations already stored in ``meta/skills.json`` for the
                requested episodes instead of running the VLM again. Stored episodes with no
                skills are re-annotated.

        Returns:
            Dictionary mapping episode index to EpisodeSkills. Reused annotations are included,
            so saving the result does not drop them.
        """
        if episodes is None:
            episode_indices = list(range(dataset.meta.total_episodes))
        else:
            episode_indices = list(episodes)
        annotations: dict[int, EpisodeSkills] = {}

        if skip_existing:
            existing = self._load_existing_annotations(dataset)
            reused = [ep_idx for ep_idx in episode_indices if ep_idx in existing]
            for ep_idx in reused:
                annotations[ep_idx] = existing[ep_idx]
            episode_indices = [ep_idx for ep_idx in episode_indices if ep_idx not in existing]
            if reused:
                self.console.print(f"[cyan]Skipping {len(reused)} already-annotated episodes[/cyan]")
            if not episode_indices:
                return annotations

        # Get coarse task description if available
        coarse_goal = self._get_coarse_goal(dataset)
        num_batches = (len(episode_indices) + self.batch_size - 1) // self.batch_size
        self.console.print(f"Annotating {len(episode_indices)} episodes in batches of {self.batch_size}...")

        for batch_num, batch_start in enumerate(range(0, len(episode_indices), self.batch_size), start=1):
            batch_episodes = episode_indices[batch_start : batch_start + self.batch_size]
            self.console.print(
                f"Processing batch {batch_num}/{num_batches} "
                f"(episodes {batch_episodes[0]} to {batch_episodes[-1]})..."
            )
            try:
                batch_annotations = self._annotate_episodes_batch(dataset, batch_episodes, video_key, coarse_goal)
            except Exception as e:
                self.console.print(f"[red]✗ Batch failed: {e}. Falling back to single-episode processing...[/red]")
                # Fallback: process episodes one by one so one bad episode does not sink the batch.
                for ep_idx in batch_episodes:
                    try:
                        skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
                    except Exception as ep_err:
                        self.console.print(f"[red]✗ Episode {ep_idx} failed: {ep_err}[/red]")
                        continue
                    self._record(annotations, ep_idx, coarse_goal, skills)
            else:
                for ep_idx, skills in batch_annotations.items():
                    self._record(annotations, ep_idx, coarse_goal, skills)
        return annotations

    def _record(
        self,
        annotations: dict[int, EpisodeSkills],
        ep_idx: int,
        coarse_goal: str,
        skills: list[Skill],
    ) -> None:
        """Store one episode's skills in ``annotations`` and report the outcome."""
        annotations[ep_idx] = EpisodeSkills(episode_index=ep_idx, description=coarse_goal, skills=skills)
        if skills:
            self.console.print(f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]")
        else:
            # segment_skills_batch returns [] when a reply cannot be parsed; don't report that as success.
            self.console.print(f"[yellow]⚠ Episode {ep_idx}: no skills identified[/yellow]")

    def _load_existing_annotations(self, dataset: LeRobotDataset) -> dict[int, EpisodeSkills]:
        """Read previously saved annotations from ``<root>/meta/skills.json``.

        Only episodes with at least one skill are returned. A missing file yields ``{}``;
        unreadable files and malformed entries are warned about and ignored.
        """
        skills_path = Path(dataset.root) / "meta" / "skills.json"
        if not skills_path.exists():
            return {}
        try:
            with open(skills_path) as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            self.console.print(f"[yellow]Warning: Could not read {skills_path}: {e}[/yellow]")
            return {}

        existing: dict[int, EpisodeSkills] = {}
        for key, ep_data in data.get("episodes", {}).items():
            try:
                ep_idx = int(key)
                skills = [Skill.from_dict(s) for s in ep_data.get("skills", [])]
            except (KeyError, TypeError, ValueError, AttributeError) as e:
                self.console.print(f"[yellow]Warning: Ignoring malformed annotation for episode {key}: {e}[/yellow]")
                continue
            if skills:
                existing[ep_idx] = EpisodeSkills(
                    episode_index=ep_idx,
                    description=ep_data.get("description", ""),
                    skills=skills,
                )
        return existing

    def _get_coarse_goal(self, dataset: LeRobotDataset) -> str:
        """Use the first task in the dataset's task table as the coarse goal, or a generic fallback."""
        if dataset.meta.tasks is not None and len(dataset.meta.tasks) > 0:
            first_task = dataset.meta.tasks.index[0]
            if first_task:
                return str(first_task)
        return "Perform the demonstrated manipulation task."

    def _extract_episode_clip(
        self,
        dataset: LeRobotDataset,
        episode_index: int,
        video_key: str,
    ) -> tuple[Path, float]:
        """Cut one episode out of its video file into a temporary clip.

        Returns:
            The temporary clip path (the caller must delete it) and the episode duration in seconds.

        Raises:
            FileNotFoundError: if the source video file does not exist.
        """
        video_path = dataset.root / dataset.meta.get_video_file_path(episode_index, video_key)
        if not video_path.exists():
            raise FileNotFoundError(f"Video not found: {video_path}")
        # The episode occupies [from_timestamp, to_timestamp] within that video file.
        ep = dataset.meta.episodes[episode_index]
        start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
        end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
        clip_path = self.video_extractor.extract_episode_video(video_path, start_ts, end_ts, target_fps=1)
        return clip_path, end_ts - start_ts

    def _annotate_episodes_batch(
        self,
        dataset: LeRobotDataset,
        episode_indices: list[int],
        video_key: str,
        coarse_goal: str,
    ) -> dict[int, list[Skill]]:
        """Annotate multiple episodes with skill labels in a single VLM batch.

        Episodes whose clip cannot be extracted are skipped with a warning and are absent from the
        result. Temporary clips are always deleted, even if extraction or the VLM call fails.

        Raises:
            RuntimeError: if the VLM returns a different number of results than clips it was given.
            Any exception from the VLM batch call is propagated.
        """
        extracted_paths: list[Path] = []
        durations: list[float] = []
        valid_episode_indices: list[int] = []
        try:
            for ep_idx in episode_indices:
                try:
                    clip_path, duration = self._extract_episode_clip(dataset, ep_idx, video_key)
                except FileNotFoundError:
                    self.console.print(f"[yellow]Warning: Video not found for episode {ep_idx}[/yellow]")
                    continue
                except Exception as e:
                    self.console.print(
                        f"[yellow]Warning: Failed to extract video for episode {ep_idx}: {e}[/yellow]"
                    )
                    continue
                extracted_paths.append(clip_path)
                durations.append(duration)
                valid_episode_indices.append(ep_idx)

            if not extracted_paths:
                return {}

            all_skills = self.vlm.segment_skills_batch(extracted_paths, durations, coarse_goal)
            # zip() would silently drop episodes on a length mismatch; fail so the caller falls back.
            if len(all_skills) != len(valid_episode_indices):
                raise RuntimeError(
                    f"VLM returned {len(all_skills)} results for {len(valid_episode_indices)} videos"
                )
            return dict(zip(valid_episode_indices, all_skills))
        finally:
            for path in extracted_paths:
                path.unlink(missing_ok=True)

    def _annotate_episode(
        self,
        dataset: LeRobotDataset,
        episode_index: int,
        video_key: str,
        coarse_goal: str,
    ) -> list[Skill]:
        """Annotate a single episode with skill labels; the temporary clip is always deleted.

        Raises:
            FileNotFoundError: if the episode's video file is missing.
        """
        clip_path, duration = self._extract_episode_clip(dataset, episode_index, video_key)
        try:
            return self.vlm.segment_skills(clip_path, duration, coarse_goal)
        finally:
            clip_path.unlink(missing_ok=True)
# =============================================================================
# Metadata Writer - Updates per-frame task_index based on skills
# =============================================================================
def get_skill_for_timestamp(skills: "list[Skill]", timestamp: float) -> "Skill | None":
    """
    Find which skill covers a given timestamp.

    A skill covers ``[start, end)``. When no skill covers the timestamp, the most recent skill
    that already started is used; this covers gaps between skills and the final frame at or
    after the last ``end``. Timestamps before the first skill map to the earliest-starting skill.

    Args:
        skills: List of skills with start/end times
        timestamp: Frame timestamp in seconds

    Returns:
        The Skill for this timestamp, or None if ``skills`` is empty
    """
    if not skills:
        return None
    for skill in skills:
        if skill.start <= timestamp < skill.end:
            return skill
    # Not covered: before the first skill, inside a gap, or at/after the end.
    ordered = sorted(skills, key=lambda s: s.start)
    started = [s for s in ordered if s.start <= timestamp]
    return started[-1] if started else ordered[0]
def update_dataset_tasks(
    dataset: LeRobotDataset,
    annotations: dict[int, EpisodeSkills],
) -> dict[str, int]:
    """
    Add every distinct skill name to the dataset's task table and return the name -> index map.

    Skill names not already present become new tasks, indexed consecutively after the current
    maximum ``task_index`` (or from 0 when the table is empty). The updated table is written
    with ``write_tasks`` only if something was added.

    Args:
        dataset: The LeRobot dataset to update
        annotations: Dictionary of episode skills

    Returns:
        Dictionary mapping skill name to task_index
    """
    import pandas as pd

    from lerobot.datasets.utils import write_tasks

    console = Console()

    skill_names: set[str] = {skill.name for episode in annotations.values() for skill in episode.skills}
    console.print(f"[cyan]Found {len(skill_names)} unique skills[/cyan]")

    current_tasks = dataset.meta.tasks
    has_tasks = current_tasks is not None and len(current_tasks) > 0
    known_names = set(current_tasks.index.tolist()) if has_tasks else set()
    next_index = current_tasks["task_index"].max() + 1 if has_tasks else 0

    missing_names = skill_names - known_names
    if missing_names:
        rows = [
            {"task": name, "task_index": next_index + offset}
            for offset, name in enumerate(sorted(missing_names))
        ]
        additions = pd.DataFrame(rows).set_index("task")
        dataset.meta.tasks = pd.concat([current_tasks, additions]) if has_tasks else additions
        # Persist the enlarged task table.
        write_tasks(dataset.meta.tasks, dataset.root)
        console.print(f"[green]✓ Added {len(missing_names)} new tasks to tasks.parquet[/green]")

    task_table = dataset.meta.tasks
    return {name: int(task_table.loc[name, "task_index"]) for name in skill_names}
def update_frame_task_indices(
    dataset: LeRobotDataset,
    annotations: dict[int, EpisodeSkills],
    skill_to_task_idx: dict[str, int],
) -> None:
    """
    Update the task_index for each frame based on skill annotations.

    Reads each data parquet file that contains annotated episodes, sets every frame's
    ``task_index`` to the task of the skill covering its ``timestamp`` (see
    ``get_skill_for_timestamp``), and rewrites the file only if something changed.

    Args:
        dataset: The LeRobot dataset to update
        annotations: Dictionary of episode skills
        skill_to_task_idx: Mapping from skill name to task_index
    """
    from collections import defaultdict

    import pandas as pd

    console = Console()

    # Group episodes by their data file (chunk_index, file_index) so each file is rewritten once.
    episodes_by_file: dict[tuple[int, int], list[int]] = defaultdict(list)
    for ep_idx in annotations:
        ep = dataset.meta.episodes[ep_idx]
        episodes_by_file[(ep["data/chunk_index"], ep["data/file_index"])].append(ep_idx)

    for (chunk_idx, file_idx), episode_indices in episodes_by_file.items():
        data_path = dataset.root / dataset.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
        if not data_path.exists():
            console.print(f"[yellow]Warning: Data file not found: {data_path}[/yellow]")
            continue

        df = pd.read_parquet(data_path)
        updated_count = 0
        for ep_idx in episode_indices:
            skills = annotations[ep_idx].skills
            ep = dataset.meta.episodes[ep_idx]
            # The episode's frames are the rows whose "index" lies in [dataset_from_index, dataset_to_index).
            episode_mask = (df["index"] >= ep["dataset_from_index"]) & (df["index"] < ep["dataset_to_index"])
            # Iterate (row label, timestamp) pairs directly; DataFrame.iterrows() builds a Series per row.
            for row_label, timestamp in df.loc[episode_mask, "timestamp"].items():
                skill = get_skill_for_timestamp(skills, timestamp)
                if skill is None or skill.name not in skill_to_task_idx:
                    continue
                new_task_idx = skill_to_task_idx[skill.name]
                if df.at[row_label, "task_index"] != new_task_idx:
                    df.at[row_label, "task_index"] = new_task_idx
                    updated_count += 1

        # Write back only if any changes were made
        if updated_count > 0:
            df.to_parquet(data_path, engine="pyarrow", compression="snappy", index=False)
            console.print(f"[green]✓ Updated {updated_count} frame task_indices in {data_path.name}[/green]")
def save_skill_annotations(
    dataset: LeRobotDataset,
    annotations: dict[int, EpisodeSkills],
    output_path: Path | None = None,
) -> None:
    """
    Save skill annotations to the dataset, updating both:
    1. The tasks.parquet with new skill names
    2. The per-frame task_index in data parquet files
    It also writes the raw annotations as JSON (default: ``<root>/meta/skills.json``).

    Args:
        dataset: The LeRobot dataset to update
        annotations: Dictionary of episode skills
        output_path: Optional custom output path for the annotations JSON (str or Path)
    """
    console = Console()
    if not annotations:
        console.print("[yellow]No annotations to save[/yellow]")
        return

    # Step 1: Register all unique skills as tasks
    console.print("[cyan]Registering skills as tasks...[/cyan]")
    skill_to_task_idx = update_dataset_tasks(dataset, annotations)

    # Step 2: Update per-frame task_index in data parquet files
    console.print("[cyan]Updating per-frame task indices...[/cyan]")
    update_frame_task_indices(dataset, annotations, skill_to_task_idx)

    # Step 3: Also save the raw skill annotations as JSON for reference
    skills_data = {
        "coarse_description": annotations[next(iter(annotations))].description,
        "skill_to_task_index": skill_to_task_idx,
        "episodes": {str(ep_idx): ann.to_dict() for ep_idx, ann in annotations.items()},
    }
    # Coerce to Path so a plain string path (e.g. from the CLI) also works with .parent below.
    skills_path = Path(output_path) if output_path else dataset.root / "meta" / "skills.json"
    skills_path.parent.mkdir(parents=True, exist_ok=True)
    with open(skills_path, "w", encoding="utf-8") as f:
        json.dump(skills_data, f, indent=2)
    console.print(f"[green]✓ Saved skill annotations to {skills_path}[/green]")

    # NOTE(review): intended to make the dataset reload hf_dataset so it sees the new task_index
    # values; this sets a private LeRobotDataset attribute - confirm it actually triggers a reload.
    dataset._lazy_loading = True
def load_skill_annotations(dataset_root: Path) -> dict | None:
"""Load existing skill annotations from a dataset."""
skills_path = dataset_root / "meta" / "skills.json"
if skills_path.exists():
with open(skills_path) as f:
return json.load(f)
return None
# =============================================================================
# Main Entry Point
# =============================================================================
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the skill annotation script."""
    parser = argparse.ArgumentParser(
        description="Automatic skill annotation for LeRobot datasets using VLMs (with batched processing)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent("""\
Examples:
# Annotate a HuggingFace Hub dataset
python annotate.py --repo-id user/dataset --video-key observation.images.base
# Annotate a local dataset with custom batch size
python annotate.py --data-dir /path/to/dataset --video-key observation.images.base --batch-size 16
# Use a specific model
python annotate.py --repo-id user/dataset --video-key observation.images.base \\
--model Qwen/Qwen2-VL-7B-Instruct
# Push annotated dataset to Hub
python annotate.py --repo-id user/dataset --video-key observation.images.base --push-to-hub
"""),
    )

    # Data source (mutually exclusive, exactly one required)
    data_group = parser.add_mutually_exclusive_group(required=True)
    data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset")
    data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID")

    # Required arguments
    parser.add_argument(
        "--video-key",
        type=str,
        required=True,
        help="Video observation key (e.g., 'observation.images.base')",
    )

    # Model configuration
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2-VL-7B-Instruct",
        help="VLM model to use for skill segmentation (default: Qwen/Qwen2-VL-7B-Instruct)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run model on (default: cuda)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
        help="Model dtype (default: bfloat16)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Number of episodes to process in each batch (default: 8)",
    )

    # Episode selection
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        help="Specific episode indices to annotate (default: all)",
    )
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="Skip episodes that already have annotations",
    )

    # Output options
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push annotated dataset to HuggingFace Hub",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        help="Custom output path for annotations JSON",
    )
    return parser


def main():
    """Main entry point for the skill annotation script.

    The script does the following, in order:
      1. parses CLI arguments,
      2. loads the dataset (local or from the Hub),
      3. runs VLM skill segmentation,
      4. saves the annotations into the dataset,
      5. optionally pushes the result to the HuggingFace Hub.

    Raises:
        SystemExit: With status 1 when the requested video key is not in the
            dataset, or when the requested Hub push fails.
    """
    args = _build_arg_parser().parse_args()
    console = Console()

    # Map the CLI dtype string to a torch dtype. argparse `choices` guarantees
    # the key exists.
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    # A local dataset is loaded under a placeholder repo id and cannot be pushed.
    # Report this before the expensive model loading/annotation, not after it.
    push_to_hub = args.push_to_hub
    if push_to_hub and args.data_dir:
        console.print("[yellow]Warning: --push-to-hub requires --repo-id, skipping...[/yellow]")
        push_to_hub = False

    # Load dataset
    console.print("[cyan]Loading dataset...[/cyan]")
    if args.data_dir:
        dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir, download_videos=False)
    else:
        dataset = LeRobotDataset(repo_id=args.repo_id, download_videos=True)
    console.print(f"[green]✓ Loaded dataset with {dataset.meta.total_episodes} episodes[/green]")

    # Validate video key. Exit non-zero so shells and CI see the failure.
    if args.video_key not in dataset.meta.video_keys:
        available = ", ".join(dataset.meta.video_keys)
        console.print(f"[red]Error: Video key '{args.video_key}' not found. Available: {available}[/red]")
        raise SystemExit(1)

    # Initialize VLM
    console.print(f"[cyan]Initializing VLM: {args.model}...[/cyan]")
    vlm = get_vlm(args.model, args.device, torch_dtype)

    # Create annotator and run annotation
    annotator = SkillAnnotator(vlm=vlm, console=console, batch_size=args.batch_size)
    console.print(f"[cyan]Processing with batch size: {args.batch_size}[/cyan]")
    annotations = annotator.annotate_dataset(
        dataset=dataset,
        video_key=args.video_key,
        episodes=args.episodes,
        skip_existing=args.skip_existing,
    )

    # Save annotations
    output_path = Path(args.output_path) if args.output_path else None
    save_skill_annotations(dataset, annotations, output_path)

    # Summary
    total_skills = sum(len(ann.skills) for ann in annotations.values())
    console.print("\n[bold green]✓ Annotation complete![/bold green]")
    console.print(f"  Episodes annotated: {len(annotations)}")
    console.print(f"  Total skills identified: {total_skills}")

    # Push to hub if requested (already disabled above for local datasets)
    if push_to_hub:
        console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]")
        try:
            dataset.push_to_hub(push_videos=False)
        except Exception as e:
            # Top-level boundary. Annotations are already saved locally, but
            # the failed push must be reported via a non-zero exit status.
            console.print(f"[red]Push failed: {e}[/red]")
            raise SystemExit(1) from e
        console.print(f"[green]✓ Pushed to {args.repo_id}[/green]")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()