#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
|
|
SARM-Style Subtask Annotation for LeRobot Datasets (Local GPU)
|
|
|
|
This script implements the annotation approach from the SARM paper using local GPU inference:
|
|
"SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation"
|
|
Paper: https://arxiv.org/pdf/2509.25358
|
|
|
|
What it does:
|
|
1. Takes videos from a LeRobot dataset
|
|
2. Uses Qwen3-VL running locally on GPU to identify when subtasks occur
|
|
3. Saves subtask timestamps to the dataset metadata
|
|
4. Optionally pushes the annotated dataset to HuggingFace Hub
|
|
|
|
SARM trains reward models that predict:
|
|
- Stage: Which subtask is currently being executed (discrete classification)
|
|
- Progress: How far along the subtask we are (continuous 0-1)
|
|
|
|
Requirements:
|
|
- GPU with sufficient VRAM (16GB+ recommended for 30B model)
|
|
- transformers, torch, qwen-vl-utils
|
|
|
|
Task-specific subtasks: Each task has a predefined list of subtasks. The model MUST use these exact names
|
|
to ensure consistency.
|
|
|
|
Usage:
|
|
# Install dependencies
|
|
pip install transformers torch qwen-vl-utils accelerate
|
|
|
|
# Annotate and push to hub:
|
|
python subtask_annotation.py \\
|
|
--repo-id pepijn223/mydataset \\
|
|
--subtasks "reach,grasp,lift,place" \\
|
|
--video-key observation.images.base \\
|
|
--push-to-hub
|
|
|
|
"""
|
|
|
|
import argparse
import json
import time
from pathlib import Path

import pandas as pd
import torch
from pydantic import BaseModel, Field
from qwen_vl_utils import process_vision_info
from rich.console import Console
from rich.panel import Panel
from rich.tree import Tree
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor

from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Pydantic Models for SARM-style Annotation
class Timestamp(BaseModel):
    """Timestamp in MM:SS or SS format"""

    start: str = Field(description="Start timestamp (MM:SS or just seconds)")
    end: str = Field(description="End timestamp (MM:SS or just seconds)")


class Subtask(BaseModel):
    """Individual subtask/stage - must use EXACT names from provided list"""

    name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
    timestamps: Timestamp


class SubtaskAnnotation(BaseModel):
    """Complete annotation for a robot manipulation episode"""

    subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
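
# Illustrative only: the JSON payload the VLM is asked to return maps directly onto the schema
# above. The subtask names below are placeholders, not taken from any real dataset.
#
#   example = SubtaskAnnotation.model_validate(
#       {
#           "subtasks": [
#               {"name": "reach", "timestamps": {"start": "00:00", "end": "00:05"}},
#               {"name": "grasp", "timestamps": {"start": "00:05", "end": "00:09"}},
#           ]
#       }
#   )
#   assert example.subtasks[0].name == "reach"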


def create_sarm_prompt(subtask_list: list[str]) -> str:
    """
    Create a SARM annotation prompt with a specific subtask list.

    The prompt instructs the VLM to identify when each subtask occurs in the video,
    using ONLY the provided subtask names (for consistency across demonstrations).
    """
    subtask_str = "\n".join([f"  - {name}" for name in subtask_list])

    return f"""You are an expert video annotator. Analyze this robot manipulation video and identify when each subtask occurs.

WATCH THE ENTIRE VIDEO FIRST:

CRITICAL REQUIREMENTS:
1. You MUST use ONLY these EXACT subtask names (no variations, no other names):
{subtask_str}
2. Identify the start and end timestamp for each subtask that occurs in the video
3. Subtasks should be in chronological order
4. Timestamps should be in MM:SS format (e.g., "00:15" for 15 seconds, "01:30" for 1 minute 30 seconds)
5. Subtasks should cover the entire demonstration without gaps
6. You MUST watch the COMPLETE video from start to finish before making ANY annotations or conclusions
7. Do NOT start annotating until you have seen the entire video length
8. Only after viewing the complete video should you identify the timestamps
9. EACH SUBTASK HAPPENS ONLY ONCE in the video - do not identify the same subtask multiple times
10. Note the exact times when each subtask starts and ends, but make sure to cover the ENTIRE video timeline.

FORMAT:
Return a JSON list of subtasks with their timestamps. Each subtask must have:
- "name": One of the exact names from the list above
- "timestamps": object with "start" and "end" fields (MM:SS format)

Example structure:
{{
  "subtasks": [
    {{"name": "reach_to_object", "timestamps": {{"start": "00:00", "end": "00:05"}}}},
    {{"name": "grasp_object", "timestamps": {{"start": "00:05", "end": "00:08"}}}},
    ...
  ]
}}

Remember: Use ONLY the subtask names provided above, and cover the ENTIRE video timeline."""
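
# Illustrative: the caller-supplied names are inlined verbatim into the prompt, e.g.
#
#   prompt = create_sarm_prompt(["reach", "grasp", "lift", "place"])
#   assert "- grasp" in prompt and "CRITICAL REQUIREMENTS" in prompt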


class VideoAnnotator:
    """Annotates robot manipulation videos using a local Qwen3-VL model on GPU"""

    def __init__(
        self,
        subtask_list: list[str],
        model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Initialize the video annotator with a local model.

        Args:
            subtask_list: List of allowed subtask names (for consistency)
            model_name: Hugging Face model name (default: Qwen/Qwen3-VL-30B-A3B-Instruct)
            device: Device to use (cuda, cpu)
            torch_dtype: Data type for the model (bfloat16, float16, float32)
        """
        self.subtask_list = subtask_list
        self.prompt = create_sarm_prompt(subtask_list)
        self.console = Console()
        self.device = device

        self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")

        # Load model and processor
        self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device,
            trust_remote_code=True,
        )
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

        self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
    def extract_episode_segment(
        self,
        file_path: Path,
        start_timestamp: float,
        end_timestamp: float,
        target_fps: int = 2,
    ) -> Path:
        """
        Extract a specific episode segment from a concatenated video.
        Uses minimal compression to preserve quality for local inference.

        Args:
            file_path: Path to the concatenated video file
            start_timestamp: Starting timestamp in seconds (within this video file)
            end_timestamp: Ending timestamp in seconds (within this video file)
            target_fps: Target FPS (default: 2 for faster processing)

        Returns:
            Path to the extracted video file
        """
        import subprocess
        import tempfile

        # Create a temporary file for the extracted video
        tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        tmp_path = Path(tmp_file.name)
        tmp_file.close()

        try:
            # Check if ffmpeg is available
            subprocess.run(
                ["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
            )
        except (subprocess.CalledProcessError, FileNotFoundError):
            self.console.print("[yellow]Warning: ffmpeg not found, cannot extract episode segment[/yellow]")
            return file_path

        try:
            # Calculate duration
            duration = end_timestamp - start_timestamp

            self.console.print(
                f"[cyan]Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)[/cyan]"
            )

            # Use ffmpeg to extract the segment with minimal quality loss
            cmd = [
                "ffmpeg",
                "-i", str(file_path),
                "-ss", str(start_timestamp),  # Start time
                "-t", str(duration),  # Duration
                "-r", str(target_fps),  # Output FPS
                "-c:v", "libx264",  # H.264 codec
                "-preset", "ultrafast",  # Faster encoding
                "-crf", "23",  # Quality (lower = better)
                "-an",  # Remove audio
                "-y",  # Overwrite output file
                str(tmp_path),
            ]

            # Run ffmpeg (suppress output)
            subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)

            # Verify the output file was created and is not empty
            if not tmp_path.exists() or tmp_path.stat().st_size == 0:
                self.console.print("[red]✗ Video extraction failed (0 bytes) - skipping episode[/red]")
                if tmp_path.exists():
                    tmp_path.unlink()
                raise RuntimeError("FFmpeg produced an empty video file")

            # Show extraction results
            file_size_mb = tmp_path.stat().st_size / (1024 * 1024)

            # Fail if the file is too small (< 100KB likely means extraction failed)
            if file_size_mb < 0.1:
                self.console.print(
                    f"[red]✗ Extracted video too small ({file_size_mb:.2f}MB) - skipping episode[/red]"
                )
                tmp_path.unlink()
                raise RuntimeError(f"Video extraction produced an invalid file ({file_size_mb:.2f}MB)")

            self.console.print(f"[green]✓ Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)[/green]")

            return tmp_path

        except subprocess.CalledProcessError as e:
            self.console.print(f"[yellow]Warning: ffmpeg failed ({e})[/yellow]")
            if tmp_path.exists():
                tmp_path.unlink()
            return file_path
    def annotate(
        self,
        file_path: str | Path,
        fps: int,
        start_timestamp: float = 0.0,
        end_timestamp: float | None = None,
        max_retries: int = 3,
    ) -> SubtaskAnnotation:
        """
        Annotate a video file or episode segment using local GPU.

        Args:
            file_path: Path to the video file (may contain multiple concatenated episodes)
            fps: Frames per second of the video
            start_timestamp: Starting timestamp in seconds (within this video file)
            end_timestamp: Ending timestamp in seconds (within this video file)
            max_retries: Number of retries if annotation fails

        Returns:
            SubtaskAnnotation object with the results
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Video file not found: {file_path}")

        # Calculate episode duration
        if end_timestamp is None:
            # Get video metadata (suppress AV1 warnings)
            import os
            import sys

            import cv2

            stderr_backup = sys.stderr
            sys.stderr = open(os.devnull, "w")

            try:
                cap = cv2.VideoCapture(str(file_path))
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                video_fps = cap.get(cv2.CAP_PROP_FPS)
                end_timestamp = total_frames / video_fps if video_fps > 0 else 0
                cap.release()
            finally:
                sys.stderr.close()
                sys.stderr = stderr_backup

        duration_seconds = end_timestamp - start_timestamp

        duration_mins = int(duration_seconds // 60)
        duration_secs = int(duration_seconds % 60)
        duration_str = f"{duration_mins:02d}:{duration_secs:02d}"

        self.console.print(f"[cyan]Processing episode from concatenated video: {file_path.name}[/cyan]")
        self.console.print(
            f"[cyan]Episode timestamps: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration_seconds:.1f}s)[/cyan]"
        )
        self.console.print(f"[cyan]Episode duration: {duration_str}[/cyan]")

        # Extract episode segment
        extracted_path = self.extract_episode_segment(
            file_path,
            start_timestamp=start_timestamp,
            end_timestamp=end_timestamp,
            target_fps=2,  # 2 FPS is a good balance for the VLM
        )
        is_extracted = extracted_path != file_path
        try:
            # Add the video duration to the prompt
            prompt_with_duration = f"""{self.prompt}

CRITICAL - VIDEO DURATION:
The video is {duration_str} long ({duration_seconds:.1f} seconds). Your annotations MUST cover the ENTIRE duration from 00:00 to {duration_str}.
Do NOT stop annotating before the video ends. Make sure your last subtask ends at {duration_str} or very close to it."""

            # Prepare messages for the model
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "video",
                            "video": str(extracted_path),
                            "fps": 1.0,  # Sample at 1 FPS for analysis
                        },
                        {"type": "text", "text": prompt_with_duration},
                    ],
                }
            ]

            # Generate annotation with retries
            for attempt in range(max_retries):
                try:
                    self.console.print(
                        f"[cyan]Generating annotation (attempt {attempt + 1}/{max_retries})...[/cyan]"
                    )

                    # Prepare inputs
                    text = self.processor.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                    image_inputs, video_inputs = process_vision_info(messages)
                    inputs = self.processor(
                        text=[text],
                        images=image_inputs,
                        videos=video_inputs,
                        padding=True,
                        return_tensors="pt",
                    )
                    inputs = inputs.to(self.device)

                    # Generate
                    with torch.no_grad():
                        generated_ids = self.model.generate(
                            **inputs,
                            max_new_tokens=2048,
                            temperature=0.1,  # Low temperature for consistent output
                        )

                    generated_ids_trimmed = [
                        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                    ]
                    response_text = self.processor.batch_decode(
                        generated_ids_trimmed,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    )[0]

                    self.console.print(f"[dim]Raw response: {response_text[:200]}...[/dim]")

                    # Try to extract JSON from the response
                    # (models sometimes wrap JSON in markdown code blocks)
                    response_text = response_text.strip()
                    if "```json" in response_text:
                        response_text = response_text.split("```json")[1].split("```")[0].strip()
                    elif "```" in response_text:
                        response_text = response_text.split("```")[1].split("```")[0].strip()

                    # Parse the response (json is imported at module level)
                    try:
                        response_dict = json.loads(response_text)
                        annotation = SubtaskAnnotation.model_validate(response_dict)
                    except json.JSONDecodeError:
                        # Try to find a JSON object in the response
                        import re

                        json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
                        if json_match:
                            response_dict = json.loads(json_match.group())
                            annotation = SubtaskAnnotation.model_validate(response_dict)
                        else:
                            raise ValueError("Could not parse JSON from model response")

                    self.console.print("[green]✓ Annotation completed successfully[/green]")
                    return annotation

                except Exception as e:
                    self.console.print(f"[yellow]⚠ Attempt {attempt + 1} failed: {e}[/yellow]")
                    if attempt == max_retries - 1:
                        raise RuntimeError(f"Failed to annotate after {max_retries} attempts") from e
                    time.sleep(1)

        finally:
            # Clean up the temporary extracted file
            if is_extracted and extracted_path.exists():
                extracted_path.unlink()
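

# Illustrative usage of the annotator (the video path, fps and subtask names below are
# placeholders, not taken from a real dataset):
#
#   annotator = VideoAnnotator(subtask_list=["reach", "grasp", "lift", "place"])
#   annotation = annotator.annotate("episode_000.mp4", fps=30, start_timestamp=0.0, end_timestamp=12.0)
#   for subtask in annotation.subtasks:
#       print(subtask.name, subtask.timestamps.start, subtask.timestamps.end)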


def display_annotation(annotation: SubtaskAnnotation, console: Console, episode_idx: int, fps: int):
    """Display annotation in a nice tree format with frame indices"""
    tree = Tree(f"[bold]Episode {episode_idx} - Subtask Annotation[/bold]")

    # Subtasks
    subtasks_branch = tree.add(f"[bold cyan]Subtasks ({len(annotation.subtasks)} total)[/bold cyan]")
    for i, subtask in enumerate(annotation.subtasks, 1):
        # Calculate frame indices for display
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)
        start_frame = int(start_sec * fps)
        end_frame = int(end_sec * fps)

        subtasks_branch.add(
            f"{i}. [cyan]{subtask.name}[/cyan]: "
            f"{subtask.timestamps.start} → {subtask.timestamps.end} "
            f"[dim](frames {start_frame}-{end_frame})[/dim]"
        )

    console.print(tree)


def timestamp_to_seconds(timestamp: str) -> float:
    """Convert MM:SS or SS timestamp to seconds"""
    parts = timestamp.split(":")
    if len(parts) == 2:
        return int(parts[0]) * 60 + int(parts[1])
    else:
        return int(parts[0])
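
# Illustrative: both timestamp forms the model may return are handled, e.g.
#
#   timestamp_to_seconds("01:30")  # -> 90
#   timestamp_to_seconds("45")     # -> 45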


def save_annotations_to_dataset(
    dataset_path: Path,
    annotations: dict[int, SubtaskAnnotation],
    fps: int,
):
    """
    Save annotations to the LeRobot dataset parquet format.

    For each episode, stores subtask annotations with:
    - subtask_names: list of subtask names
    - subtask_start_times: list of start times (seconds)
    - subtask_end_times: list of end times (seconds)
    - subtask_start_frames: list of start frames
    - subtask_end_frames: list of end frames
    """
    from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes

    console = Console()

    # Load existing episodes metadata (returns a datasets.Dataset)
    episodes_dataset = load_episodes(dataset_path)

    if episodes_dataset is None or len(episodes_dataset) == 0:
        console.print("[red]Error: No episodes found in dataset[/red]")
        return

    # Convert to a pandas DataFrame for easier manipulation
    episodes_df = episodes_dataset.to_pandas()

    # Add subtask columns to the episodes dataframe
    episodes_df["subtask_names"] = None
    episodes_df["subtask_start_times"] = None
    episodes_df["subtask_end_times"] = None
    episodes_df["subtask_start_frames"] = None
    episodes_df["subtask_end_frames"] = None

    # Fill in annotations
    for ep_idx, annotation in annotations.items():
        if ep_idx >= len(episodes_df):
            console.print(f"[yellow]Warning: Episode {ep_idx} not found in dataset[/yellow]")
            continue

        subtask_names = []
        start_times = []
        end_times = []
        start_frames = []
        end_frames = []

        for subtask in annotation.subtasks:
            subtask_names.append(subtask.name)

            # Convert timestamps to seconds
            start_sec = timestamp_to_seconds(subtask.timestamps.start)
            end_sec = timestamp_to_seconds(subtask.timestamps.end)
            start_times.append(start_sec)
            end_times.append(end_sec)

            # Calculate frame indices from timestamps and FPS
            start_frame = int(start_sec * fps)
            end_frame = int(end_sec * fps)
            start_frames.append(start_frame)
            end_frames.append(end_frame)

        # Store as lists in the dataframe
        episodes_df.at[ep_idx, "subtask_names"] = subtask_names
        episodes_df.at[ep_idx, "subtask_start_times"] = start_times
        episodes_df.at[ep_idx, "subtask_end_times"] = end_times
        episodes_df.at[ep_idx, "subtask_start_frames"] = start_frames
        episodes_df.at[ep_idx, "subtask_end_frames"] = end_frames

    # Group episodes by their chunk and file indices
    episodes_by_file = {}
    for ep_idx in episodes_df.index:
        chunk_idx = episodes_df.loc[ep_idx, "meta/episodes/chunk_index"]
        file_idx = episodes_df.loc[ep_idx, "meta/episodes/file_index"]
        key = (chunk_idx, file_idx)

        if key not in episodes_by_file:
            episodes_by_file[key] = []
        episodes_by_file[key].append(ep_idx)

    # Write back to parquet files
    for (chunk_idx, file_idx), ep_indices in episodes_by_file.items():
        episodes_path = dataset_path / DEFAULT_EPISODES_PATH.format(
            chunk_index=chunk_idx, file_index=file_idx
        )

        if not episodes_path.exists():
            console.print(f"[yellow]Warning: Episodes file not found: {episodes_path}[/yellow]")
            continue

        # Read the existing parquet file
        file_df = pd.read_parquet(episodes_path)

        # Add subtask columns if they don't exist
        for col in [
            "subtask_names",
            "subtask_start_times",
            "subtask_end_times",
            "subtask_start_frames",
            "subtask_end_frames",
        ]:
            if col not in file_df.columns:
                file_df[col] = None

        # Update rows that have annotations
        for ep_idx in ep_indices:
            if ep_idx in file_df.index and ep_idx in annotations:
                file_df.at[ep_idx, "subtask_names"] = episodes_df.loc[ep_idx, "subtask_names"]
                file_df.at[ep_idx, "subtask_start_times"] = episodes_df.loc[ep_idx, "subtask_start_times"]
                file_df.at[ep_idx, "subtask_end_times"] = episodes_df.loc[ep_idx, "subtask_end_times"]
                file_df.at[ep_idx, "subtask_start_frames"] = episodes_df.loc[ep_idx, "subtask_start_frames"]
                file_df.at[ep_idx, "subtask_end_frames"] = episodes_df.loc[ep_idx, "subtask_end_frames"]

        # Write back to parquet
        file_df.to_parquet(episodes_path, engine="pyarrow", compression="snappy")
        console.print(
            f"[green]✓ Updated {episodes_path.name} with "
            f"{len([e for e in ep_indices if e in annotations])} annotations[/green]"
        )

    console.print(f"[bold green]✓ Saved {len(annotations)} episode annotations to parquet files[/bold green]")


def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
    """
    Load annotations from LeRobot dataset parquet files.

    Reads subtask annotations from the episodes metadata parquet files.
    """
    from lerobot.datasets.utils import load_episodes

    episodes_dataset = load_episodes(dataset_path)

    if episodes_dataset is None or len(episodes_dataset) == 0:
        return {}

    # Check if subtask columns exist
    if "subtask_names" not in episodes_dataset.column_names:
        return {}

    # Convert to pandas DataFrame for easier access
    episodes_df = episodes_dataset.to_pandas()

    annotations = {}

    for ep_idx in episodes_df.index:
        subtask_names = episodes_df.loc[ep_idx, "subtask_names"]

        # Skip episodes without annotations
        if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
            continue

        start_times = episodes_df.loc[ep_idx, "subtask_start_times"]
        end_times = episodes_df.loc[ep_idx, "subtask_end_times"]

        # Reconstruct SubtaskAnnotation from stored data
        subtasks = []
        for i, name in enumerate(subtask_names):
            # Convert seconds back to MM:SS format
            start_sec = int(start_times[i])
            end_sec = int(end_times[i])
            start_str = f"{start_sec // 60:02d}:{start_sec % 60:02d}"
            end_str = f"{end_sec // 60:02d}:{end_sec % 60:02d}"

            subtasks.append(
                Subtask(
                    name=name,
                    timestamps=Timestamp(start=start_str, end=end_str),
                )
            )

        annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)

    return annotations
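
# Illustrative round-trip (the local dataset path is a placeholder):
#
#   annotations = load_annotations_from_dataset(Path("/path/to/local/dataset"))
#   if annotations:
#       display_annotation(annotations[0], Console(), episode_idx=0, fps=30)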


def process_single_episode(
    ep_idx: int,
    dataset_root: Path,
    dataset_meta,
    video_key: str,
    fps: int,
    annotator: VideoAnnotator,
    console: Console,
) -> tuple[int, SubtaskAnnotation | None, str | None]:
    """
    Process a single episode annotation.

    Args:
        ep_idx: Episode index
        dataset_root: Dataset root path
        dataset_meta: Dataset metadata
        video_key: Video key to use
        fps: FPS of the video
        annotator: VideoAnnotator instance
        console: Console for output

    Returns:
        Tuple of (episode_index, annotation or None, error message or None)
    """
    try:
        # Get video path
        video_path = dataset_root / dataset_meta.get_video_file_path(ep_idx, video_key)

        if not video_path.exists():
            return ep_idx, None, f"Video not found: {video_path}"

        # Get video-specific timestamps (NOT global frame indices)
        video_path_key = f"videos/{video_key}/from_timestamp"
        video_path_key_to = f"videos/{video_key}/to_timestamp"

        start_timestamp = float(dataset_meta.episodes[video_path_key][ep_idx])
        end_timestamp = float(dataset_meta.episodes[video_path_key_to][ep_idx])

        # Annotate with video-specific timestamps
        annotation = annotator.annotate(
            video_path,
            fps,
            start_timestamp=start_timestamp,
            end_timestamp=end_timestamp,
        )

        return ep_idx, annotation, None

    except Exception as e:
        return ep_idx, None, str(e)


def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
    """
    Compute average temporal proportion for each subtask across all episodes.
    This is the key insight from SARM - use semantic subtasks instead of frame indices.
    """
    # Collect all proportions per subtask
    subtask_proportions = {}

    for annotation in annotations.values():
        # Calculate per-subtask durations and the total episode duration
        total_duration = 0
        durations = {}

        for subtask in annotation.subtasks:
            start_seconds = timestamp_to_seconds(subtask.timestamps.start)
            end_seconds = timestamp_to_seconds(subtask.timestamps.end)

            duration = end_seconds - start_seconds
            durations[subtask.name] = duration
            total_duration += duration

        # Calculate proportions for this episode
        if total_duration > 0:
            for name, duration in durations.items():
                if name not in subtask_proportions:
                    subtask_proportions[name] = []
                subtask_proportions[name].append(duration / total_duration)

    # Average across episodes
    avg_proportions = {
        name: sum(props) / len(props)
        for name, props in subtask_proportions.items()
    }

    # Normalize to sum to 1.0
    total = sum(avg_proportions.values())
    if total > 0:
        avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}

    return avg_proportions
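
# Worked example (illustrative numbers): a single episode with reach 00:00-00:04,
# grasp 00:04-00:06 and place 00:06-00:12 has durations of 4s, 2s and 6s out of 12s,
# so the returned proportions are roughly {"reach": 0.33, "grasp": 0.17, "place": 0.50}.
#
#   annotations = {0: SubtaskAnnotation(subtasks=[
#       Subtask(name="reach", timestamps=Timestamp(start="00:00", end="00:04")),
#       Subtask(name="grasp", timestamps=Timestamp(start="00:04", end="00:06")),
#       Subtask(name="place", timestamps=Timestamp(start="00:06", end="00:12")),
#   ])}
#   compute_temporal_proportions(annotations)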


def main():
    parser = argparse.ArgumentParser(
        description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # List available cameras:
  python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --max-episodes 0

  # Annotate with specific camera:
  python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --video-key observation.images.top --push-to-hub

  # Use custom model:
  python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --video-key observation.images.top --model Qwen/Qwen3-VL-30B-A3B-Instruct --push-to-hub

Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memory.
        """,
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID (e.g., 'pepijn223/mydataset')",
    )
    parser.add_argument(
        "--subtasks",
        type=str,
        required=True,
        help="Comma-separated list of subtask names (e.g., 'reach,grasp,lift,place')",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        default=None,
        help="Specific episode indices to annotate (default: all episodes)",
    )
    parser.add_argument(
        "--max-episodes",
        type=int,
        default=None,
        help="Maximum number of episodes to annotate",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-VL-30B-A3B-Instruct",
        help="Qwen3-VL model to use (default: Qwen/Qwen3-VL-30B-A3B-Instruct)",
    )
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="Skip episodes that already have annotations",
    )
    parser.add_argument(
        "--video-key",
        type=str,
        default=None,
        help="Camera/video key to use for annotation (e.g., 'observation.images.top'). "
        "If not specified, uses the first available video key.",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push annotated dataset to HuggingFace Hub",
    )
    parser.add_argument(
        "--output-repo-id",
        type=str,
        default=None,
        help="Output repository ID for push (default: same as --repo-id)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use (cuda, cpu)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
        help="Model dtype (default: bfloat16)",
    )

    args = parser.parse_args()

    # Parse subtask list
    subtask_list = [s.strip() for s in args.subtasks.split(",")]

    # Parse dtype
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    console = Console()
    console.print(
        Panel.fit(
            "[bold cyan]SARM Subtask Annotation (Local GPU)[/bold cyan]\n"
            f"Dataset: {args.repo_id}\n"
            f"Model: {args.model}\n"
            f"Device: {args.device}\n"
            f"Subtasks: {', '.join(subtask_list)}",
            border_style="cyan",
        )
    )

    # Load dataset
    console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
    dataset = LeRobotDataset(args.repo_id, download_videos=True)

    # Get FPS from dataset
    fps = dataset.fps
    console.print(f"[cyan]Dataset FPS: {fps}[/cyan]")

    # Display available cameras/video keys
    if len(dataset.meta.video_keys) == 0:
        console.print("[red]Error: No video keys found in dataset[/red]")
        return

    console.print("\n[cyan]Available cameras/video keys:[/cyan]")
    for i, vk in enumerate(dataset.meta.video_keys, 1):
        console.print(f"  {i}. {vk}")

    # Get video key
    if args.video_key:
        if args.video_key not in dataset.meta.video_keys:
            console.print(f"[red]Error: Video key '{args.video_key}' not found in dataset[/red]")
            console.print(f"[yellow]Available keys: {', '.join(dataset.meta.video_keys)}[/yellow]")
            return
        video_key = args.video_key
        console.print(f"\n[green]Using specified camera: {video_key}[/green]")
    else:
        video_key = dataset.meta.video_keys[0]
        console.print(f"\n[yellow]No camera specified, using first available: {video_key}[/yellow]")
        if len(dataset.meta.video_keys) > 1:
            console.print("[yellow]Tip: Use --video-key to specify a different camera[/yellow]")

    # Determine episodes to annotate
    if args.episodes:
        episode_indices = args.episodes
    else:
        episode_indices = list(range(dataset.meta.total_episodes))
        # Explicit None check so `--max-episodes 0` can be used to list cameras without annotating
        if args.max_episodes is not None:
            episode_indices = episode_indices[: args.max_episodes]

    console.print(f"[cyan]Will annotate {len(episode_indices)} episodes[/cyan]")

    # Load existing annotations
    existing_annotations = load_annotations_from_dataset(dataset.root)

    if args.skip_existing and existing_annotations:
        console.print(f"[yellow]Found {len(existing_annotations)} existing annotations[/yellow]")
        episode_indices = [ep for ep in episode_indices if ep not in existing_annotations]
        console.print(f"[cyan]Will annotate {len(episode_indices)} new episodes[/cyan]")

    if not episode_indices:
        console.print("[green]All episodes already annotated![/green]")
        return

    # Initialize annotator with subtask list
    annotator = VideoAnnotator(
        subtask_list=subtask_list,
        model_name=args.model,
        device=args.device,
        torch_dtype=torch_dtype,
    )

    # Annotate episodes (sequential processing)
    annotations = existing_annotations.copy()

    for i, ep_idx in enumerate(episode_indices):
        console.print(f"\n[bold cyan]{'=' * 60}[/bold cyan]")
        console.print(f"[bold cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/bold cyan]")
        console.print(f"[bold cyan]{'=' * 60}[/bold cyan]")

        result_ep_idx, annotation, error = process_single_episode(
            ep_idx,
            dataset.root,
            dataset.meta,
            video_key,
            fps,
            annotator,
            console,
        )

        if error:
            console.print(f"[red]✗ Failed to annotate episode {result_ep_idx}: {error}[/red]")
            continue
        elif annotation:
            annotations[result_ep_idx] = annotation
            display_annotation(annotation, console, result_ep_idx, fps)
            save_annotations_to_dataset(dataset.root, annotations, fps)

    # Compute temporal proportions (key SARM insight)
    console.print("\n[bold cyan]Computing Temporal Proportions[/bold cyan]")
    temporal_proportions = compute_temporal_proportions(annotations, fps)

    # Save temporal proportions
    proportions_path = dataset.root / "meta" / "temporal_proportions.json"
    proportions_path.parent.mkdir(parents=True, exist_ok=True)
    with open(proportions_path, "w") as f:
        json.dump(temporal_proportions, f, indent=2)

    console.print(f"[green]✓ Saved temporal proportions to {proportions_path}[/green]")
    console.print("\n[cyan]Average temporal proportions:[/cyan]")
    for name, proportion in sorted(temporal_proportions.items(), key=lambda x: -x[1]):
        console.print(f"  {name}: {proportion:.1%}")

    # Create summary
    console.print(f"\n[bold green]{'=' * 60}[/bold green]")
    console.print("[bold green]Annotation Complete![/bold green]")
    console.print(f"[bold green]{'=' * 60}[/bold green]")
    console.print(f"Total episodes annotated: {len(annotations)}")
    console.print(f"Total subtasks found: {sum(len(ann.subtasks) for ann in annotations.values())}")

    # Push to hub if requested
    if args.push_to_hub:
        output_repo = args.output_repo_id if args.output_repo_id else args.repo_id
        console.print(f"\n[bold cyan]Pushing to HuggingFace Hub: {output_repo}[/bold cyan]")

        try:
            dataset.push_to_hub(push_videos=True)
            console.print(f"[bold green]✓ Successfully pushed to {output_repo}[/bold green]")
        except Exception as e:
            console.print(f"[red]✗ Failed to push to hub: {e}[/red]")
            console.print("[yellow]Annotations are still saved locally[/yellow]")


if __name__ == "__main__":
    main()