#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SARM-Style Subtask Annotation for LeRobot Datasets (Local GPU)

This script implements the annotation approach from the SARM paper using local GPU inference:
"SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation"
Paper: https://arxiv.org/pdf/2509.25358

What it does:
1. Takes videos from a LeRobot dataset
2. Uses Qwen3-VL running locally on GPU to identify when subtasks occur
3. Saves subtask timestamps to the dataset metadata
4. Optionally pushes the annotated dataset to the HuggingFace Hub

SARM trains reward models that predict:
- Stage: Which subtask is currently being executed (discrete classification)
- Progress: How far along the subtask we are (continuous 0-1)

Requirements:
- GPU with sufficient VRAM (~60GB per GPU for the 30B model in bfloat16)
- transformers, torch, qwen-vl-utils

Task-specific subtasks:
Each task has a predefined list of subtasks. The model MUST use these exact
names to ensure consistency across demonstrations.

Usage:
    # Install dependencies
    pip install transformers torch qwen-vl-utils accelerate

    # Sequential processing (single GPU):
    python subtask_annotation.py \\
        --repo-id pepijn223/mydataset \\
        --subtasks "reach,grasp,lift,place" \\
        --video-key observation.images.base \\
        --push-to-hub

    # Parallel processing (4 GPUs):
    python subtask_annotation.py \\
        --repo-id pepijn223/mydataset \\
        --subtasks "reach,grasp,lift,place" \\
        --video-key observation.images.base \\
        --num-workers 4 \\
        --push-to-hub

    # Parallel with specific GPU IDs:
    python subtask_annotation.py \\
        --repo-id pepijn223/mydataset \\
        --subtasks "reach,grasp,lift,place" \\
        --video-key observation.images.base \\
        --num-workers 2 \\
        --gpu-ids 0 2 \\
        --push-to-hub
"""

import argparse
import json
import multiprocessing as mp
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pandas as pd
import torch
from qwen_vl_utils import process_vision_info
from rich.console import Console
from rich.panel import Panel
from rich.tree import Tree
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.sarm_utils import (
    Subtask,
    SubtaskAnnotation,
    Timestamp,
    compute_temporal_proportions,
)


def create_sarm_prompt(subtask_list: list[str]) -> str:
    """
    Create a SARM annotation prompt with a specific subtask list.

    The prompt instructs the VLM to identify when each subtask occurs in the
    video, using ONLY the provided subtask names (for consistency across
    demonstrations).
    """
    subtask_str = "\n".join([f"  - {name}" for name in subtask_list])

    return f"""# Role
You are an expert Robotics Vision System specializing in temporal action localization. Your task is to segment a video of a robot manipulation demonstration into a sequence of distinct, non-overlapping atomic actions.
# Input Data
## Allowed Subtask Vocabulary
You must strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones:
[
{subtask_str}
]

# Constraints & Logic
1. **Continuous Coverage:** The entire video duration (from 00:00 to the final second) must be accounted for. There can be no gaps between tasks.
2. **Boundary Logic:** The `end` timestamp of one task must be the exact `start` timestamp of the next task.
3. **Linear Progression:** The video represents a single successful demonstration. Each subtask from the vocabulary appears exactly once, in logical chronological order.
4. **Format:** Timestamps must be in "MM:SS" format.

# Step-by-Step Analysis Process
1. **Visual Grounding:** Look for the specific visual state changes that define the transition between tasks (e.g., gripper touching object, object lifting off table).
2. **Define Boundaries:** Determine the specific frame where the motion profile changes to fit the next subtask label.
3. **Fill Gaps:** If there is a pause between meaningful actions, append that time to the *preceding* task to ensure continuous coverage.

# Output Format
Provide the output in valid JSON format. Structure:
{{
    "subtasks": [
        {{"name": "EXACT_NAME_FROM_LIST", "timestamps": {{"start": "MM:SS", "end": "MM:SS"}}}},
        {{"name": "EXACT_NAME_FROM_LIST", "timestamps": {{"start": "MM:SS", "end": "MM:SS"}}}}
    ]
}}
"""
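
# A minimal usage sketch (hypothetical vocabulary): render the prompt for a
# pick-and-place task and check that every allowed label appears verbatim.
#
#     labels = ["reach", "grasp", "lift", "place"]
#     prompt = create_sarm_prompt(labels)
#     assert all(name in prompt for name in labels)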

class VideoAnnotator:
    """Annotates robot manipulation videos using a local Qwen3-VL model on GPU."""

    def __init__(
        self,
        subtask_list: list[str],
        model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Initialize the video annotator with a local model.

        Args:
            subtask_list: List of allowed subtask names (for consistency)
            model_name: Hugging Face model name (default: Qwen/Qwen3-VL-30B-A3B-Instruct)
            device: Device to use (cuda, cpu)
            torch_dtype: Data type for the model (bfloat16, float16, float32)
        """
        self.subtask_list = subtask_list
        self.prompt = create_sarm_prompt(subtask_list)
        self.console = Console()
        self.device = device

        self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
        self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")

    def extract_episode_segment(
        self,
        file_path: Path,
        start_timestamp: float,
        end_timestamp: float,
        target_fps: int = 1,
    ) -> Path:
        """
        Extract a specific episode segment from a concatenated video.

        Uses minimal compression to preserve quality for local inference.

        Args:
            file_path: Path to the concatenated video file
            start_timestamp: Starting timestamp in seconds (within this video file)
            end_timestamp: Ending timestamp in seconds (within this video file)
            target_fps: Target FPS (default: 1 for faster processing)

        Returns:
            Path to the extracted video file, or the original path if
            extraction is not possible.
        """
        import subprocess
        import tempfile

        # Create a temporary file for the extracted video
        tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        tmp_path = Path(tmp_file.name)
        tmp_file.close()

        try:
            # Check that ffmpeg is available
            subprocess.run(
                ["ffmpeg", "-version"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True,
            )
        except (subprocess.CalledProcessError, FileNotFoundError):
            self.console.print("[yellow]Warning: ffmpeg not found, cannot extract episode segment[/yellow]")
            return file_path

        try:
            duration = end_timestamp - start_timestamp
            self.console.print(
                f"[cyan]Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)[/cyan]"
            )

            # Use ffmpeg to extract the segment with minimal quality loss
            cmd = [
                "ffmpeg",
                "-i", str(file_path),
                "-ss", str(start_timestamp),
                "-t", str(duration),
                "-r", str(target_fps),
                "-c:v", "libx264",
                "-preset", "ultrafast",
                "-crf", "23",
                "-an",
                "-y",
                str(tmp_path),
            ]
            subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)

            # Verify the output file was created and is not empty
            if not tmp_path.exists() or tmp_path.stat().st_size == 0:
                self.console.print("[red]✗ Video extraction failed (0 bytes) - skipping episode[/red]")
                if tmp_path.exists():
                    tmp_path.unlink()
                raise RuntimeError("FFmpeg produced an empty video file")

            file_size_mb = tmp_path.stat().st_size / (1024 * 1024)

            # Fail if the file is too small (< 100KB likely means extraction failed)
            if file_size_mb < 0.1:
                self.console.print(
                    f"[red]✗ Extracted video too small ({file_size_mb:.2f}MB) - skipping episode[/red]"
                )
                tmp_path.unlink()
                raise RuntimeError(f"Video extraction produced an invalid file ({file_size_mb:.2f}MB)")

            self.console.print(f"[green]✓ Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)[/green]")
            return tmp_path

        except subprocess.CalledProcessError as e:
            self.console.print(f"[yellow]Warning: ffmpeg failed ({e})[/yellow]")
            if tmp_path.exists():
                tmp_path.unlink()
            return file_path
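
    # For reference, the extraction above is equivalent to running (hypothetical
    # paths and timestamps):
    #
    #     ffmpeg -i videos/file-000.mp4 -ss 12.0 -t 30.0 -r 1 \
    #            -c:v libx264 -preset ultrafast -crf 23 -an -y /tmp/segment.mp4
    #
    # i.e. re-encode a 30 s window starting at 12 s, resampled to 1 FPS with the
    # audio stripped, which keeps the clip small enough for VLM inference.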
    def annotate(
        self,
        file_path: str | Path,
        fps: int,
        start_timestamp: float = 0.0,
        end_timestamp: float | None = None,
        max_retries: int = 3,
    ) -> SubtaskAnnotation:
        """
        Annotate a video file or episode segment using the local GPU.

        Args:
            file_path: Path to the video file (may contain multiple concatenated episodes)
            fps: Frames per second of the video
            start_timestamp: Starting timestamp in seconds (within this video file)
            end_timestamp: Ending timestamp in seconds (within this video file)
            max_retries: Number of retries if annotation fails

        Returns:
            SubtaskAnnotation object with the results
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Video file not found: {file_path}")

        # Determine the episode end from the video itself if not provided
        if end_timestamp is None:
            import os
            import sys

            import cv2

            # Temporarily silence OpenCV's stderr chatter while probing the video
            stderr_backup = sys.stderr
            sys.stderr = open(os.devnull, "w")
            try:
                cap = cv2.VideoCapture(str(file_path))
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                video_fps = cap.get(cv2.CAP_PROP_FPS)
                end_timestamp = total_frames / video_fps if video_fps > 0 else 0
                cap.release()
            finally:
                sys.stderr.close()
                sys.stderr = stderr_backup

        duration_seconds = end_timestamp - start_timestamp
        duration_mins = int(duration_seconds // 60)
        duration_secs = int(duration_seconds % 60)
        duration_str = f"{duration_mins:02d}:{duration_secs:02d}"

        self.console.print(f"[cyan]Processing episode from concatenated video: {file_path.name}[/cyan]")
        self.console.print(
            f"[cyan]Episode timestamps: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration_seconds:.1f}s)[/cyan]"
        )
        self.console.print(f"[cyan]Episode duration: {duration_str}[/cyan]")

        # Extract the episode segment
        extracted_path = self.extract_episode_segment(
            file_path,
            start_timestamp=start_timestamp,
            end_timestamp=end_timestamp,
            target_fps=1,
        )
        is_extracted = extracted_path != file_path

        try:
            # Append the video duration to the prompt so the model covers the full clip
            prompt_with_duration = f"""{self.prompt}

# Video Duration:
The video is {duration_str} long ({duration_seconds:.1f} seconds).
Your total annotations MUST cover the ENTIRE duration from 00:00 to {duration_str}.
Do NOT stop annotating before the video ends. Make sure your last subtask ends at {duration_str}."""

            # Prepare messages for the model
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "video",
                            "video": str(extracted_path),
                            "fps": 1.0,  # Sample at 1 FPS for analysis
                        },
                        {"type": "text", "text": prompt_with_duration},
                    ],
                }
            ]

            # Generate the annotation, with retries
            for attempt in range(max_retries):
                try:
                    self.console.print(
                        f"[cyan]Generating annotation (attempt {attempt + 1}/{max_retries})...[/cyan]"
                    )

                    # Prepare inputs
                    text = self.processor.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                    image_inputs, video_inputs = process_vision_info(messages)
                    inputs = self.processor(
                        text=[text],
                        images=image_inputs,
                        videos=video_inputs,
                        padding=True,
                        return_tensors="pt",
                    )
                    inputs = inputs.to(self.device)

                    # Generate (sampling must be enabled for the temperature to take effect)
                    with torch.no_grad():
                        generated_ids = self.model.generate(
                            **inputs,
                            max_new_tokens=2048,
                            do_sample=True,
                            temperature=0.1,  # Low temperature for consistent output
                        )
                    generated_ids_trimmed = [
                        out_ids[len(in_ids):]
                        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                    ]
                    response_text = self.processor.batch_decode(
                        generated_ids_trimmed,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    )[0]

                    self.console.print(f"[dim]Raw response: {response_text[:200]}...[/dim]")

                    # Extract the JSON from the response; models sometimes wrap
                    # it in markdown code blocks
                    response_text = response_text.strip()
                    if "```json" in response_text:
                        response_text = response_text.split("```json")[1].split("```")[0].strip()
                    elif "```" in response_text:
                        response_text = response_text.split("```")[1].split("```")[0].strip()

                    # Parse the response
                    try:
                        response_dict = json.loads(response_text)
                        annotation = SubtaskAnnotation.model_validate(response_dict)
                    except json.JSONDecodeError:
                        # Fall back to the first JSON object found in the response
                        import re

                        json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
                        if json_match:
                            response_dict = json.loads(json_match.group())
                            annotation = SubtaskAnnotation.model_validate(response_dict)
                        else:
                            raise ValueError("Could not parse JSON from model response")

                    self.console.print("[green]✓ Annotation completed successfully[/green]")
                    return annotation

                except Exception as e:
                    self.console.print(f"[yellow]⚠ Attempt {attempt + 1} failed: {e}[/yellow]")
                    if attempt == max_retries - 1:
                        raise RuntimeError(f"Failed to annotate after {max_retries} attempts") from e
                    time.sleep(1)

        finally:
            # Clean up the temporary extracted file
            if is_extracted and extracted_path.exists():
                extracted_path.unlink()
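
# A minimal parsing sketch (hypothetical model output): the model is expected
# to return JSON shaped like the structure in the prompt, which pydantic
# validates into a SubtaskAnnotation.
#
#     raw = '{"subtasks": [{"name": "reach", "timestamps": {"start": "00:00", "end": "00:05"}}]}'
#     ann = SubtaskAnnotation.model_validate(json.loads(raw))
#     assert ann.subtasks[0].name == "reach"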

def display_annotation(annotation: SubtaskAnnotation, console: Console, episode_idx: int, fps: int):
    """Display an annotation as a tree, with frame indices."""
    tree = Tree(f"[bold]Episode {episode_idx} - Subtask Annotation[/bold]")

    subtasks_branch = tree.add(f"[bold cyan]Subtasks ({len(annotation.subtasks)} total)[/bold cyan]")
    for i, subtask in enumerate(annotation.subtasks, 1):
        # Convert timestamps to frame indices for display
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)
        start_frame = int(start_sec * fps)
        end_frame = int(end_sec * fps)
        subtasks_branch.add(
            f"{i}. [cyan]{subtask.name}[/cyan]: "
            f"{subtask.timestamps.start} → {subtask.timestamps.end} "
            f"[dim](frames {start_frame}-{end_frame})[/dim]"
        )

    console.print(tree)


def timestamp_to_seconds(timestamp: str) -> float:
    """Convert an "MM:SS" or "SS" timestamp to seconds."""
    parts = timestamp.split(":")
    if len(parts) == 2:
        return int(parts[0]) * 60 + int(parts[1])
    return int(parts[0])
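
# Quick examples:
#
#     timestamp_to_seconds("01:30")  # -> 90
#     timestamp_to_seconds("45")     # -> 45
#
# At 30 FPS, "01:30" therefore maps to frame int(90 * 30) = 2700.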

def save_annotations_to_dataset(
    dataset_path: Path,
    annotations: dict[int, SubtaskAnnotation],
    fps: int,
):
    """
    Save annotations to the LeRobot dataset parquet metadata.

    For each episode, stores subtask annotations as:
    - subtask_names: list of subtask names
    - subtask_start_times: list of start times (seconds)
    - subtask_end_times: list of end times (seconds)
    - subtask_start_frames: list of start frames
    - subtask_end_frames: list of end frames
    """
    from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes

    console = Console()

    # Load the existing episodes metadata (returns a datasets.Dataset)
    episodes_dataset = load_episodes(dataset_path)
    if episodes_dataset is None or len(episodes_dataset) == 0:
        console.print("[red]Error: No episodes found in dataset[/red]")
        return

    # Convert to a pandas DataFrame for easier manipulation
    episodes_df = episodes_dataset.to_pandas()

    # Add the subtask columns to the episodes dataframe
    subtask_columns = [
        "subtask_names",
        "subtask_start_times",
        "subtask_end_times",
        "subtask_start_frames",
        "subtask_end_frames",
    ]
    for col in subtask_columns:
        episodes_df[col] = None

    # Fill in the annotations
    for ep_idx, annotation in annotations.items():
        if ep_idx >= len(episodes_df):
            console.print(f"[yellow]Warning: Episode {ep_idx} not found in dataset[/yellow]")
            continue

        subtask_names = []
        start_times = []
        end_times = []
        start_frames = []
        end_frames = []

        for subtask in annotation.subtasks:
            subtask_names.append(subtask.name)

            # Convert timestamps to seconds
            start_sec = timestamp_to_seconds(subtask.timestamps.start)
            end_sec = timestamp_to_seconds(subtask.timestamps.end)
            start_times.append(start_sec)
            end_times.append(end_sec)

            # Calculate frame indices from the timestamps and FPS
            start_frames.append(int(start_sec * fps))
            end_frames.append(int(end_sec * fps))

        # Store as lists in the dataframe
        episodes_df.at[ep_idx, "subtask_names"] = subtask_names
        episodes_df.at[ep_idx, "subtask_start_times"] = start_times
        episodes_df.at[ep_idx, "subtask_end_times"] = end_times
        episodes_df.at[ep_idx, "subtask_start_frames"] = start_frames
        episodes_df.at[ep_idx, "subtask_end_frames"] = end_frames

    # Group episodes by their chunk and file indices
    episodes_by_file = {}
    for ep_idx in episodes_df.index:
        chunk_idx = episodes_df.loc[ep_idx, "meta/episodes/chunk_index"]
        file_idx = episodes_df.loc[ep_idx, "meta/episodes/file_index"]
        episodes_by_file.setdefault((chunk_idx, file_idx), []).append(ep_idx)

    # Write back to the parquet files
    for (chunk_idx, file_idx), ep_indices in episodes_by_file.items():
        episodes_path = dataset_path / DEFAULT_EPISODES_PATH.format(
            chunk_index=chunk_idx, file_index=file_idx
        )
        if not episodes_path.exists():
            console.print(f"[yellow]Warning: Episodes file not found: {episodes_path}[/yellow]")
            continue

        # Read the existing parquet file and index it by episode_index, so that
        # global episode indices resolve correctly even when the metadata is
        # split across multiple files
        file_df = pd.read_parquet(episodes_path)
        if "episode_index" in file_df.columns:
            file_df = file_df.set_index("episode_index", drop=False)

        # Add the subtask columns if they don't exist
        for col in subtask_columns:
            if col not in file_df.columns:
                file_df[col] = None

        # Update the rows that have annotations
        updated = 0
        for ep_idx in ep_indices:
            if ep_idx in file_df.index and ep_idx in annotations:
                for col in subtask_columns:
                    file_df.at[ep_idx, col] = episodes_df.loc[ep_idx, col]
                updated += 1

        # Write back to parquet
        file_df.reset_index(drop=True).to_parquet(episodes_path, engine="pyarrow", compression="snappy")
        console.print(f"[green]✓ Updated {episodes_path.name} with {updated} annotations[/green]")

    console.print(f"[bold green]✓ Saved {len(annotations)} episode annotations to parquet files[/bold green]")


def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
    """
    Load annotations from the LeRobot dataset parquet files.

    Reads subtask annotations from the episodes metadata parquet files.
    """
    from lerobot.datasets.utils import load_episodes

    episodes_dataset = load_episodes(dataset_path)
    if episodes_dataset is None or len(episodes_dataset) == 0:
        return {}

    # Check whether the subtask columns exist
    if "subtask_names" not in episodes_dataset.column_names:
        return {}

    # Convert to a pandas DataFrame for easier access
    episodes_df = episodes_dataset.to_pandas()

    annotations = {}
    for ep_idx in episodes_df.index:
        subtask_names = episodes_df.loc[ep_idx, "subtask_names"]

        # Skip episodes without annotations
        if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
            continue

        start_times = episodes_df.loc[ep_idx, "subtask_start_times"]
        end_times = episodes_df.loc[ep_idx, "subtask_end_times"]

        # Reconstruct the SubtaskAnnotation from the stored data
        subtasks = []
        for i, name in enumerate(subtask_names):
            # Convert seconds back to MM:SS format
            start_sec = int(start_times[i])
            end_sec = int(end_times[i])
            start_str = f"{start_sec // 60:02d}:{start_sec % 60:02d}"
            end_str = f"{end_sec // 60:02d}:{end_sec % 60:02d}"
            subtasks.append(Subtask(name=name, timestamps=Timestamp(start=start_str, end=end_str)))

        annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)

    return annotations
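
# After saving, an annotated episode row carries columns like (hypothetical
# values for a 4-subtask episode at 30 FPS):
#
#     subtask_names        = ["reach", "grasp", "lift", "place"]
#     subtask_start_times  = [0.0, 3.0, 8.0, 12.0]
#     subtask_end_times    = [3.0, 8.0, 12.0, 18.0]
#     subtask_start_frames = [0, 90, 240, 360]
#     subtask_end_frames   = [90, 240, 360, 540]
#
# load_annotations_from_dataset() reverses this mapping, so a save/load round
# trip reproduces the same SubtaskAnnotation objects (rounded to whole seconds).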

def process_single_episode(
    ep_idx: int,
    dataset_root: Path,
    dataset_meta,
    video_key: str,
    fps: int,
    annotator: VideoAnnotator,
    console: Console,
) -> tuple[int, SubtaskAnnotation | None, str | None]:
    """
    Process a single episode annotation.

    Args:
        ep_idx: Episode index
        dataset_root: Dataset root path
        dataset_meta: Dataset metadata
        video_key: Video key to use
        fps: FPS of the video
        annotator: VideoAnnotator instance
        console: Console for output

    Returns:
        Tuple of (episode index, annotation or None, error message or None)
    """
    try:
        # Get the video path
        video_path = dataset_root / dataset_meta.get_video_file_path(ep_idx, video_key)
        if not video_path.exists():
            return ep_idx, None, f"Video not found: {video_path}"

        # Get video-specific timestamps (NOT global frame indices)
        from_key = f"videos/{video_key}/from_timestamp"
        to_key = f"videos/{video_key}/to_timestamp"
        start_timestamp = float(dataset_meta.episodes[from_key][ep_idx])
        end_timestamp = float(dataset_meta.episodes[to_key][ep_idx])

        # Annotate using the video-specific timestamps
        annotation = annotator.annotate(
            video_path, fps, start_timestamp=start_timestamp, end_timestamp=end_timestamp
        )
        return ep_idx, annotation, None

    except Exception as e:
        return ep_idx, None, str(e)


def worker_process_episodes(
    worker_id: int,
    gpu_id: int,
    episode_indices: list[int],
    repo_id: str,
    video_key: str,
    subtask_list: list[str],
    model_name: str,
    torch_dtype: torch.dtype,
) -> dict[int, SubtaskAnnotation]:
    """
    Worker function for parallel processing across GPUs.

    Args:
        worker_id: Worker ID for logging
        gpu_id: GPU device ID to use
        episode_indices: List of episode indices to process
        repo_id: Dataset repo ID
        video_key: Video key to use
        subtask_list: List of subtask names
        model_name: Model name to load
        torch_dtype: Model dtype

    Returns:
        Dictionary mapping episode_index -> SubtaskAnnotation
    """
    device = f"cuda:{gpu_id}"

    console = Console()
    console.print(f"[cyan]Worker {worker_id} starting on GPU {gpu_id} with {len(episode_indices)} episodes[/cyan]")

    # Load the dataset metadata (videos were already downloaded by the main process)
    dataset = LeRobotDataset(repo_id, download_videos=False)
    fps = dataset.fps

    # Each worker loads its own model instance on its assigned GPU
    annotator = VideoAnnotator(
        subtask_list=subtask_list, model_name=model_name, device=device, torch_dtype=torch_dtype
    )

    # Process the assigned episodes
    annotations = {}
    for i, ep_idx in enumerate(episode_indices):
        console.print(f"[cyan]Worker {worker_id} | Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/cyan]")

        result_ep_idx, annotation, error = process_single_episode(
            ep_idx, dataset.root, dataset.meta, video_key, fps, annotator, console
        )

        if error:
            console.print(f"[red]Worker {worker_id} | ✗ Failed episode {result_ep_idx}: {error}[/red]")
        elif annotation:
            annotations[result_ep_idx] = annotation
            console.print(f"[green]Worker {worker_id} | ✓ Completed episode {result_ep_idx}[/green]")

    console.print(
        f"[bold green]Worker {worker_id} completed {len(annotations)}/{len(episode_indices)} episodes[/bold green]"
    )
    return annotations
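
# Temporal proportions (a key SARM signal) summarize how much of an episode
# each subtask typically occupies. Assuming compute_temporal_proportions()
# averages per-episode duration fractions, two hypothetical episodes:
#
#     ep0: reach 0-2s, grasp 2-10s  -> reach 0.2, grasp 0.8
#     ep1: reach 0-4s, grasp 4-10s  -> reach 0.4, grasp 0.6
#
# would average to {"reach": 0.3, "grasp": 0.7}, which main() below writes to
# meta/temporal_proportions.json.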

def main():
    parser = argparse.ArgumentParser(
        description="SARM-style subtask annotation using a local GPU (Qwen3-VL)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Sequential processing (single GPU):
  python subtask_annotation.py \\
      --repo-id pepijn223/mydataset \\
      --subtasks "reach,grasp,lift,place" \\
      --video-key observation.images.top \\
      --push-to-hub

  # Parallel processing with 4 GPUs (~4x speedup):
  python subtask_annotation.py \\
      --repo-id pepijn223/mydataset \\
      --subtasks "reach,grasp,lift,place" \\
      --video-key observation.images.top \\
      --num-workers 4 \\
      --push-to-hub

Performance remarks:
  - Each worker loads one model instance on its assigned GPU
  - The 30B model requires ~60GB VRAM per GPU
  - Use --num-workers N for N GPUs
""",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID (e.g., 'pepijn223/mydataset')",
    )
    parser.add_argument(
        "--subtasks",
        type=str,
        required=True,
        help="Comma-separated list of subtask names (e.g., 'reach,grasp,lift,place')",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        default=None,
        help="Specific episode indices to annotate (default: all episodes)",
    )
    parser.add_argument(
        "--max-episodes",
        type=int,
        default=None,
        help="Maximum number of episodes to annotate",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-VL-30B-A3B-Instruct",
        help="Qwen3-VL model to use (default: Qwen/Qwen3-VL-30B-A3B-Instruct)",
    )
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="Skip episodes that already have annotations",
    )
    parser.add_argument(
        "--video-key",
        type=str,
        default=None,
        help="Camera/video key to use for annotation (e.g., 'observation.images.top'). "
        "If not specified, uses the first available video key.",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the annotated dataset to the HuggingFace Hub",
    )
    parser.add_argument(
        "--output-repo-id",
        type=str,
        default=None,
        help="Output repository ID for the push (default: same as --repo-id)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use (cuda, cpu)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
        help="Model dtype (default: bfloat16)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=1,
        help="Number of parallel workers for multi-GPU processing (default: 1 for sequential). "
        "Set to the number of available GPUs for parallel processing.",
    )
    parser.add_argument(
        "--gpu-ids",
        type=int,
        nargs="+",
        default=None,
        help="Specific GPU IDs to use (e.g., --gpu-ids 0 1 2). "
        "If not specified, uses GPUs 0 to num-workers-1.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size for processing multiple episodes per inference (experimental, currently unused, default: 1)",
    )
    args = parser.parse_args()

    subtask_list = [s.strip() for s in args.subtasks.split(",")]

    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    console = Console()
    console.print(
        Panel.fit(
            "[bold cyan]SARM Subtask Annotation (Local GPU)[/bold cyan]\n"
            f"Dataset: {args.repo_id}\n"
            f"Model: {args.model}\n"
            f"Device: {args.device}\n"
            f"Subtasks: {', '.join(subtask_list)}",
            border_style="cyan",
        )
    )

    console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
    dataset = LeRobotDataset(args.repo_id, download_videos=True)
    fps = dataset.fps
    console.print(f"[cyan]Dataset FPS: {fps}[/cyan]")

    # Display the available cameras/video keys
    if len(dataset.meta.video_keys) == 0:
        console.print("[red]Error: No video keys found in dataset[/red]")
        return

    console.print("\n[cyan]Available cameras/video keys:[/cyan]")
    for i, vk in enumerate(dataset.meta.video_keys, 1):
        console.print(f"  {i}. {vk}")
{vk}") # Get video key if args.video_key: if args.video_key not in dataset.meta.video_keys: console.print(f"[red]Error: Video key '{args.video_key}' not found in dataset[/red]") console.print(f"[yellow]Available keys: {', '.join(dataset.meta.video_keys)}[/yellow]") return video_key = args.video_key console.print(f"\n[green]Using specified camera: {video_key}[/green]") else: video_key = dataset.meta.video_keys[0] console.print(f"\n[yellow]No camera specified, using first available: {video_key}[/yellow]") if len(dataset.meta.video_keys) > 1: console.print(f"[yellow]Tip: Use --video-key to specify a different camera[/yellow]") # Determine episodes to annotate if args.episodes: episode_indices = args.episodes else: episode_indices = list(range(dataset.meta.total_episodes)) if args.max_episodes: episode_indices = episode_indices[: args.max_episodes] console.print(f"[cyan]Will annotate {len(episode_indices)} episodes[/cyan]") # Load existing annotations existing_annotations = load_annotations_from_dataset(dataset.root) if args.skip_existing and existing_annotations: console.print(f"[yellow]Found {len(existing_annotations)} existing annotations[/yellow]") episode_indices = [ep for ep in episode_indices if ep not in existing_annotations] console.print(f"[cyan]Will annotate {len(episode_indices)} new episodes[/cyan]") if not episode_indices: console.print("[green]All episodes already annotated![/green]") return # Determine GPU IDs to use if args.gpu_ids: gpu_ids = args.gpu_ids if len(gpu_ids) < args.num_workers: console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {len(gpu_ids)} GPU IDs provided[/yellow]") args.num_workers = len(gpu_ids) else: # Check available GPUs if torch.cuda.is_available(): num_gpus = torch.cuda.device_count() if args.num_workers > num_gpus: console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {num_gpus} GPUs available[/yellow]") args.num_workers = min(args.num_workers, num_gpus) gpu_ids = list(range(args.num_workers)) else: console.print("[yellow]Warning: CUDA not available, using CPU (num_workers will be ignored)[/yellow]") args.num_workers = 1 gpu_ids = [0] # Dummy value for CPU # Annotate episodes - choose sequential or parallel mode annotations = existing_annotations.copy() if args.num_workers > 1: # ===== PARALLEL PROCESSING MODE ===== console.print(f"\n[bold cyan]Using {args.num_workers} parallel workers on GPUs: {gpu_ids}[/bold cyan]") # Split episodes across workers episodes_per_worker = [[] for _ in range(args.num_workers)] for i, ep_idx in enumerate(episode_indices): worker_idx = i % args.num_workers episodes_per_worker[worker_idx].append(ep_idx) # Show distribution for worker_id, episodes in enumerate(episodes_per_worker): console.print(f"[cyan]Worker {worker_id} (GPU {gpu_ids[worker_id]}): {len(episodes)} episodes[/cyan]") # Start parallel processing using ProcessPoolExecutor console.print(f"\n[bold cyan]Starting parallel annotation...[/bold cyan]") # Use 'spawn' method for CUDA compatibility (required for multi-GPU) mp_context = mp.get_context('spawn') with ProcessPoolExecutor(max_workers=args.num_workers, mp_context=mp_context) as executor: # Submit all worker jobs futures = [] for worker_id in range(args.num_workers): if not episodes_per_worker[worker_id]: continue # Skip workers with no episodes future = executor.submit( worker_process_episodes, worker_id, gpu_ids[worker_id], episodes_per_worker[worker_id], args.repo_id, video_key, subtask_list, args.model, torch_dtype, ) futures.append(future) # Collect 
            # Collect results as the workers complete
            for future in as_completed(futures):
                try:
                    worker_annotations = future.result()
                    annotations.update(worker_annotations)

                    # Save after each worker completes
                    save_annotations_to_dataset(dataset.root, annotations, fps)
                    console.print(f"[green]✓ Worker completed, saved {len(worker_annotations)} annotations[/green]")
                except Exception as e:
                    console.print(f"[red]✗ Worker failed: {e}[/red]")

        console.print(
            f"\n[bold green]Parallel processing complete! Annotated {len(annotations)} episodes[/bold green]"
        )

        # Display only the newly annotated episodes
        for ep_idx in sorted(annotations.keys()):
            if ep_idx not in existing_annotations:
                display_annotation(annotations[ep_idx], console, ep_idx, fps)

    else:
        # ===== SEQUENTIAL PROCESSING MODE =====
        console.print("\n[bold cyan]Using sequential processing (single GPU/CPU)[/bold cyan]")

        # Initialize the annotator with the subtask list
        annotator = VideoAnnotator(
            subtask_list=subtask_list,
            model_name=args.model,
            device=args.device,
            torch_dtype=torch_dtype,
        )

        # Process the episodes sequentially
        for i, ep_idx in enumerate(episode_indices):
            console.print(f"\n[bold cyan]{'=' * 60}[/bold cyan]")
            console.print(f"[bold cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/bold cyan]")
            console.print(f"[bold cyan]{'=' * 60}[/bold cyan]")

            result_ep_idx, annotation, error = process_single_episode(
                ep_idx, dataset.root, dataset.meta, video_key, fps, annotator, console
            )

            if error:
                console.print(f"[red]✗ Failed to annotate episode {result_ep_idx}: {error}[/red]")
                continue
            if annotation:
                annotations[result_ep_idx] = annotation
                display_annotation(annotation, console, result_ep_idx, fps)
                save_annotations_to_dataset(dataset.root, annotations, fps)

    # Compute temporal proportions (a key SARM insight)
    console.print("\n[bold cyan]Computing Temporal Proportions[/bold cyan]")
    temporal_proportions = compute_temporal_proportions(annotations, fps)

    # Save the temporal proportions
    proportions_path = dataset.root / "meta" / "temporal_proportions.json"
    proportions_path.parent.mkdir(parents=True, exist_ok=True)
    with open(proportions_path, "w") as f:
        json.dump(temporal_proportions, f, indent=2)
    console.print(f"[green]✓ Saved temporal proportions to {proportions_path}[/green]")

    console.print("\n[cyan]Average temporal proportions:[/cyan]")
    for name, proportion in sorted(temporal_proportions.items(), key=lambda x: -x[1]):
        console.print(f"  {name}: {proportion:.1%}")

    # Summary
    console.print(f"\n[bold green]{'=' * 60}[/bold green]")
    console.print("[bold green]Annotation Complete![/bold green]")
    console.print(f"[bold green]{'=' * 60}[/bold green]")
    console.print(f"Total episodes annotated: {len(annotations)}")
    console.print(f"Total subtasks found: {sum(len(ann.subtasks) for ann in annotations.values())}")

    # Push to the Hub if requested
    if args.push_to_hub:
        output_repo = args.output_repo_id if args.output_repo_id else args.repo_id
        console.print(f"\n[bold cyan]Pushing to HuggingFace Hub: {output_repo}[/bold cyan]")
        try:
            if output_repo != args.repo_id:
                # LeRobotDataset.push_to_hub pushes to the repo the dataset was
                # loaded from, so point it at the requested output repo instead.
                dataset.repo_id = output_repo
            dataset.push_to_hub(push_videos=True)
            console.print(f"[bold green]✓ Successfully pushed to {output_repo}[/bold green]")
        except Exception as e:
            console.print(f"[red]✗ Failed to push to hub: {e}[/red]")
            console.print("[yellow]Annotations are still saved locally[/yellow]")


if __name__ == "__main__":
    main()