mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
1189 lines
44 KiB
Python
1189 lines
44 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
"""
|
||
SARM-Style Subtask Annotation for LeRobot Datasets (Local GPU)
|
||
|
||
This script implements the annotation approach from the SARM paper using local GPU inference:
|
||
"SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation"
|
||
Paper: https://arxiv.org/pdf/2509.25358
|
||
|
||
What it does:
|
||
1. Takes videos from a LeRobot dataset
|
||
2. Uses Qwen3-VL running locally on GPU to identify when subtasks occur
|
||
3. Saves subtask timestamps to the dataset metadata
|
||
4. Optionally pushes the annotated dataset to HuggingFace Hub
|
||
|
||
SARM trains reward models that predict:
|
||
- Stage: Which subtask is currently being executed (discrete classification)
|
||
- Progress: How far along the subtask we are (continuous 0-1)
|
||
|
||
Requirements:
|
||
- GPU with sufficient VRAM (16GB+ recommended for 30B model)
|
||
- transformers, torch, qwen-vl-utils
|
||
|
||
Task-specific subtasks: Each task has a predefined list of subtasks. The model MUST use these exact names
|
||
to ensure consistency.
|
||
|
||
Usage:
|
||
# Install dependencies
|
||
pip install transformers torch qwen-vl-utils accelerate
|
||
|
||
# Sequential processing (single GPU):
|
||
python subtask_annotation.py \\
|
||
--repo-id pepijn223/mydataset \\
|
||
--subtasks "reach,grasp,lift,place" \\
|
||
--video-key observation.images.base \\
|
||
--push-to-hub
|
||
|
||
# Parallel processing (4 GPUs):
|
||
python subtask_annotation.py \\
|
||
--repo-id pepijn223/mydataset \\
|
||
--subtasks "reach,grasp,lift,place" \\
|
||
--video-key observation.images.base \\
|
||
--num-workers 4 \\
|
||
--push-to-hub
|
||
|
||
# Parallel with specific GPU IDs:
|
||
python subtask_annotation.py \\
|
||
--repo-id pepijn223/mydataset \\
|
||
--subtasks "reach,grasp,lift,place" \\
|
||
--video-key observation.images.base \\
|
||
--num-workers 2 \\
|
||
--gpu-ids 0 2 \\
|
||
--push-to-hub
|
||
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||
import multiprocessing as mp
|
||
|
||
import pandas as pd
|
||
import torch
|
||
from pydantic import BaseModel, Field
|
||
from qwen_vl_utils import process_vision_info
|
||
from rich.console import Console
|
||
from rich.panel import Panel
|
||
from rich.tree import Tree
|
||
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
||
|
||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||
|
||
# Pydantic Models for SARM-style Annotation
|
||
class Timestamp(BaseModel):
|
||
"""Timestamp in MM:SS or SS format"""
|
||
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
|
||
end: str = Field(description="End timestamp (MM:SS or just seconds)")
|
||
|
||
|
||
class Subtask(BaseModel):
|
||
"""Individual subtask/stage - must use EXACT names from provided list"""
|
||
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
|
||
timestamps: Timestamp
|
||
|
||
|
||
class SubtaskAnnotation(BaseModel):
|
||
"""Complete annotation for a robot manipulation episode"""
|
||
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
|
||
|
||
|
||
def create_sarm_prompt(subtask_list: list[str]) -> str:
|
||
"""
|
||
Create a SARM annotation prompt with a specific subtask list.
|
||
|
||
The prompt instructs the VLM to identify when each subtask occurs in the video,
|
||
using ONLY the provided subtask names (for consistency across demonstrations).
|
||
"""
|
||
subtask_str = "\n".join([f" - {name}" for name in subtask_list])
|
||
|
||
return f"""You are an expert video annotator. Analyze this robot manipulation video and identify when each subtask occurs.
|
||
|
||
WATCH THE ENTIRE VIDEO FIRST:
|
||
|
||
|
||
CRITICAL REQUIREMENTS:
|
||
1. You MUST use ONLY these EXACT subtask names (no variations, no other names):
|
||
{subtask_str}
|
||
2. Identify the start and end timestamp for each subtask that occurs in the video
|
||
3. Subtasks should be in chronological order
|
||
4. Timestamps should be in MM:SS format (e.g., "00:15" for 15 seconds, "01:30" for 1 minute 30 seconds)
|
||
5. Subtasks should cover the entire demonstration without gaps
|
||
6. You MUST watch the COMPLETE video from start to finish before making ANY annotations or conclusions
|
||
7. Do NOT start annotating until you have seen the entire video length
|
||
8. Only after viewing the complete video should you identify the timestamps
|
||
9. EACH SUBTASK HAPPENS ONLY ONCE in the video - do not identify the same subtask multiple times
|
||
10. Note the exact times when each subtask starts and ends, but make sure to cover the ENTIRE video timeline.
|
||
|
||
FORMAT:
|
||
Return a JSON list of subtasks with their timestamps. Each subtask must have:
|
||
- "name": One of the exact names from the list above
|
||
- "timestamps": object with "start" and "end" fields (MM:SS format)
|
||
|
||
Example structure:
|
||
{{
|
||
"subtasks": [
|
||
{{"name": "reach_to_object", "timestamps": {{"start": "00:00", "end": "00:05"}}}},
|
||
{{"name": "grasp_object", "timestamps": {{"start": "00:05", "end": "00:08"}}}},
|
||
...
|
||
]
|
||
}}
|
||
|
||
Remember: Use ONLY the subtask names provided above, and cover the ENTIRE video timeline."""
|
||
|
||
|
||
class VideoAnnotator:
|
||
"""Annotates robot manipulation videos using local Qwen3-VL model on GPU"""
|
||
|
||
def __init__(
|
||
self,
|
||
subtask_list: list[str],
|
||
model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct",
|
||
device: str = "cuda",
|
||
torch_dtype: torch.dtype = torch.bfloat16,
|
||
):
|
||
"""
|
||
Initialize the video annotator with local model.
|
||
|
||
Args:
|
||
subtask_list: List of allowed subtask names (for consistency)
|
||
model_name: Hugging Face model name (default: Qwen/Qwen3-VL-30B-A3B-Instruct)
|
||
device: Device to use (cuda, cpu)
|
||
torch_dtype: Data type for model (bfloat16, float16, float32)
|
||
"""
|
||
self.subtask_list = subtask_list
|
||
self.prompt = create_sarm_prompt(subtask_list)
|
||
self.console = Console()
|
||
self.device = device
|
||
|
||
self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
|
||
|
||
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
||
model_name,
|
||
torch_dtype=torch_dtype,
|
||
device_map=device,
|
||
trust_remote_code=True
|
||
)
|
||
|
||
self.processor = AutoProcessor.from_pretrained(
|
||
model_name,
|
||
trust_remote_code=True
|
||
)
|
||
|
||
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
||
|
||
def extract_episode_segment(
|
||
self,
|
||
file_path: Path,
|
||
start_timestamp: float,
|
||
end_timestamp: float,
|
||
target_fps: int = 1
|
||
) -> Path:
|
||
"""
|
||
Extract a specific episode segment from concatenated video.
|
||
Uses minimal compression to preserve quality for local inference.
|
||
|
||
Args:
|
||
file_path: Path to the concatenated video file
|
||
start_timestamp: Starting timestamp in seconds (within this video file)
|
||
end_timestamp: Ending timestamp in seconds (within this video file)
|
||
target_fps: Target FPS (default: 1 for faster processing)
|
||
|
||
Returns:
|
||
Path to extracted video file
|
||
"""
|
||
import os
|
||
import tempfile
|
||
import subprocess
|
||
|
||
# Create temporary file for extracted video
|
||
tmp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
|
||
tmp_path = Path(tmp_file.name)
|
||
tmp_file.close()
|
||
|
||
try:
|
||
# Check if ffmpeg is available
|
||
subprocess.run(['ffmpeg', '-version'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
|
||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||
self.console.print("[yellow]Warning: ffmpeg not found, cannot extract episode segment[/yellow]")
|
||
return file_path
|
||
|
||
try:
|
||
# Calculate duration
|
||
duration = end_timestamp - start_timestamp
|
||
|
||
self.console.print(f"[cyan]Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)[/cyan]")
|
||
|
||
# Use ffmpeg to extract segment with minimal quality loss
|
||
cmd = [
|
||
'ffmpeg',
|
||
'-i', str(file_path),
|
||
'-ss', str(start_timestamp),
|
||
'-t', str(duration),
|
||
'-r', str(target_fps),
|
||
'-c:v', 'libx264',
|
||
'-preset', 'ultrafast',
|
||
'-crf', '23',
|
||
'-an',
|
||
'-y',
|
||
str(tmp_path)
|
||
]
|
||
|
||
subprocess.run(
|
||
cmd,
|
||
stdout=subprocess.DEVNULL,
|
||
stderr=subprocess.DEVNULL,
|
||
check=True
|
||
)
|
||
|
||
# Verify the output file was created and is not empty
|
||
if not tmp_path.exists() or tmp_path.stat().st_size == 0:
|
||
self.console.print("[red]✗ Video extraction failed (0 bytes) - skipping episode[/red]")
|
||
if tmp_path.exists():
|
||
tmp_path.unlink()
|
||
raise RuntimeError("FFmpeg produced empty video file")
|
||
|
||
# Show extraction results
|
||
file_size_mb = tmp_path.stat().st_size / (1024 * 1024)
|
||
|
||
# Fail if file is too small (< 100KB likely means extraction failed)
|
||
if file_size_mb < 0.1:
|
||
self.console.print(f"[red]✗ Extracted video too small ({file_size_mb:.2f}MB) - skipping episode[/red]")
|
||
tmp_path.unlink()
|
||
raise RuntimeError(f"Video extraction produced invalid file ({file_size_mb:.2f}MB)")
|
||
|
||
self.console.print(f"[green]✓ Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)[/green]")
|
||
|
||
return tmp_path
|
||
|
||
except subprocess.CalledProcessError as e:
|
||
self.console.print(f"[yellow]Warning: ffmpeg failed ({e})[/yellow]")
|
||
if tmp_path.exists():
|
||
tmp_path.unlink()
|
||
return file_path
|
||
|
||
def annotate(
|
||
self,
|
||
file_path: str | Path,
|
||
fps: int,
|
||
start_timestamp: float = 0.0,
|
||
end_timestamp: float | None = None,
|
||
max_retries: int = 3
|
||
) -> SubtaskAnnotation:
|
||
"""
|
||
Annotate a video file or episode segment using local GPU.
|
||
|
||
Args:
|
||
file_path: Path to the video file (may contain multiple concatenated episodes)
|
||
fps: Frames per second of the video
|
||
start_timestamp: Starting timestamp in seconds (within this video file)
|
||
end_timestamp: Ending timestamp in seconds (within this video file)
|
||
max_retries: Number of retries if annotation fails
|
||
|
||
Returns:
|
||
SubtaskAnnotation object with the results
|
||
"""
|
||
file_path = Path(file_path)
|
||
if not file_path.exists():
|
||
raise FileNotFoundError(f"Video file not found: {file_path}")
|
||
|
||
# Calculate episode duration
|
||
if end_timestamp is None:
|
||
import cv2
|
||
import os
|
||
import sys
|
||
|
||
stderr_backup = sys.stderr
|
||
sys.stderr = open(os.devnull, 'w')
|
||
|
||
try:
|
||
cap = cv2.VideoCapture(str(file_path))
|
||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||
video_fps = cap.get(cv2.CAP_PROP_FPS)
|
||
end_timestamp = total_frames / video_fps if video_fps > 0 else 0
|
||
cap.release()
|
||
finally:
|
||
sys.stderr.close()
|
||
sys.stderr = stderr_backup
|
||
|
||
duration_seconds = end_timestamp - start_timestamp
|
||
|
||
duration_mins = int(duration_seconds // 60)
|
||
duration_secs = int(duration_seconds % 60)
|
||
duration_str = f"{duration_mins:02d}:{duration_secs:02d}"
|
||
|
||
self.console.print(f"[cyan]Processing episode from concatenated video: {file_path.name}[/cyan]")
|
||
self.console.print(f"[cyan]Episode timestamps: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration_seconds:.1f}s)[/cyan]")
|
||
self.console.print(f"[cyan]Episode duration: {duration_str}[/cyan]")
|
||
|
||
# Extract episode segment
|
||
extracted_path = self.extract_episode_segment(
|
||
file_path,
|
||
start_timestamp=start_timestamp,
|
||
end_timestamp=end_timestamp,
|
||
target_fps=1
|
||
)
|
||
is_extracted = extracted_path != file_path
|
||
|
||
try:
|
||
# Add video duration to prompt
|
||
prompt_with_duration = f"""{self.prompt}
|
||
|
||
CRITICAL - VIDEO DURATION:
|
||
The video is {duration_str} long ({duration_seconds:.1f} seconds). Your annotations MUST cover the ENTIRE duration from 00:00 to {duration_str}.
|
||
Do NOT stop annotating before the video ends. Make sure your last subtask ends at {duration_str} or very close to it."""
|
||
|
||
# Prepare messages for the model
|
||
messages = [
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{
|
||
"type": "video",
|
||
"video": str(extracted_path),
|
||
"fps": 1.0, # Sample at 1 FPS for analysis
|
||
},
|
||
{"type": "text", "text": prompt_with_duration},
|
||
],
|
||
}
|
||
]
|
||
|
||
# Generate annotation with retries
|
||
for attempt in range(max_retries):
|
||
try:
|
||
self.console.print(f"[cyan]Generating annotation (attempt {attempt + 1}/{max_retries})...[/cyan]")
|
||
|
||
# Prepare inputs
|
||
text = self.processor.apply_chat_template(
|
||
messages, tokenize=False, add_generation_prompt=True
|
||
)
|
||
|
||
image_inputs, video_inputs = process_vision_info(messages)
|
||
|
||
inputs = self.processor(
|
||
text=[text],
|
||
images=image_inputs,
|
||
videos=video_inputs,
|
||
padding=True,
|
||
return_tensors="pt",
|
||
)
|
||
inputs = inputs.to(self.device)
|
||
|
||
# Generate
|
||
with torch.no_grad():
|
||
generated_ids = self.model.generate(
|
||
**inputs,
|
||
max_new_tokens=2048,
|
||
temperature=0.1, # Low temperature for consistent output
|
||
)
|
||
|
||
generated_ids_trimmed = [
|
||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||
]
|
||
|
||
response_text = self.processor.batch_decode(
|
||
generated_ids_trimmed,
|
||
skip_special_tokens=True,
|
||
clean_up_tokenization_spaces=False
|
||
)[0]
|
||
|
||
self.console.print(f"[dim]Raw response: {response_text[:200]}...[/dim]")
|
||
|
||
# Try to extract JSON from response
|
||
# Sometimes models wrap JSON in markdown code blocks
|
||
response_text = response_text.strip()
|
||
if "```json" in response_text:
|
||
response_text = response_text.split("```json")[1].split("```")[0].strip()
|
||
elif "```" in response_text:
|
||
response_text = response_text.split("```")[1].split("```")[0].strip()
|
||
|
||
# Parse response
|
||
import json
|
||
try:
|
||
response_dict = json.loads(response_text)
|
||
annotation = SubtaskAnnotation.model_validate(response_dict)
|
||
except json.JSONDecodeError:
|
||
# Try to find JSON object in response
|
||
import re
|
||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||
if json_match:
|
||
response_dict = json.loads(json_match.group())
|
||
annotation = SubtaskAnnotation.model_validate(response_dict)
|
||
else:
|
||
raise ValueError("Could not parse JSON from model response")
|
||
|
||
self.console.print("[green]✓ Annotation completed successfully[/green]")
|
||
return annotation
|
||
|
||
except Exception as e:
|
||
self.console.print(f"[yellow]⚠ Attempt {attempt + 1} failed: {e}[/yellow]")
|
||
if attempt == max_retries - 1:
|
||
raise RuntimeError(f"Failed to annotate after {max_retries} attempts") from e
|
||
time.sleep(1)
|
||
|
||
finally:
|
||
# Clean up temporary extracted file
|
||
if is_extracted and extracted_path.exists():
|
||
extracted_path.unlink()
|
||
|
||
|
||
def display_annotation(annotation: SubtaskAnnotation, console: Console, episode_idx: int, fps: int):
|
||
"""Display annotation in a nice tree format with frame indices"""
|
||
tree = Tree(f"[bold]Episode {episode_idx} - Subtask Annotation[/bold]")
|
||
|
||
# Subtasks
|
||
subtasks_branch = tree.add(f"[bold cyan]Subtasks ({len(annotation.subtasks)} total)[/bold cyan]")
|
||
for i, subtask in enumerate(annotation.subtasks, 1):
|
||
# Calculate frame indices for display
|
||
start_sec = timestamp_to_seconds(subtask.timestamps.start)
|
||
end_sec = timestamp_to_seconds(subtask.timestamps.end)
|
||
start_frame = int(start_sec * fps)
|
||
end_frame = int(end_sec * fps)
|
||
|
||
subtasks_branch.add(
|
||
f"{i}. [cyan]{subtask.name}[/cyan]: "
|
||
f"{subtask.timestamps.start} → {subtask.timestamps.end} "
|
||
f"[dim](frames {start_frame}-{end_frame})[/dim]"
|
||
)
|
||
|
||
console.print(tree)
|
||
|
||
|
||
def timestamp_to_seconds(timestamp: str) -> float:
|
||
"""Convert MM:SS or SS timestamp to seconds"""
|
||
parts = timestamp.split(":")
|
||
if len(parts) == 2:
|
||
return int(parts[0]) * 60 + int(parts[1])
|
||
else:
|
||
return int(parts[0])
|
||
|
||
|
||
def save_annotations_to_dataset(
|
||
dataset_path: Path,
|
||
annotations: dict[int, SubtaskAnnotation],
|
||
fps: int,
|
||
):
|
||
"""
|
||
Save annotations to LeRobot dataset parquet format.
|
||
|
||
For each episode, stores subtask annotations with:
|
||
- subtask_names: list of subtask names
|
||
- subtask_start_times: list of start times (seconds)
|
||
- subtask_end_times: list of end times (seconds)
|
||
- subtask_start_frames: list of start frames
|
||
- subtask_end_frames: list of end frames
|
||
"""
|
||
import pandas as pd
|
||
import pyarrow.parquet as pq
|
||
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes
|
||
|
||
console = Console()
|
||
|
||
# Load existing episodes metadata (returns datasets.Dataset)
|
||
episodes_dataset = load_episodes(dataset_path)
|
||
|
||
if episodes_dataset is None or len(episodes_dataset) == 0:
|
||
console.print("[red]Error: No episodes found in dataset[/red]")
|
||
return
|
||
|
||
# Convert to pandas DataFrame for easier manipulation
|
||
episodes_df = episodes_dataset.to_pandas()
|
||
|
||
# Add subtask columns to episodes dataframe
|
||
episodes_df["subtask_names"] = None
|
||
episodes_df["subtask_start_times"] = None
|
||
episodes_df["subtask_end_times"] = None
|
||
episodes_df["subtask_start_frames"] = None
|
||
episodes_df["subtask_end_frames"] = None
|
||
|
||
# Fill in annotations
|
||
for ep_idx, annotation in annotations.items():
|
||
if ep_idx >= len(episodes_df):
|
||
console.print(f"[yellow]Warning: Episode {ep_idx} not found in dataset[/yellow]")
|
||
continue
|
||
|
||
subtask_names = []
|
||
start_times = []
|
||
end_times = []
|
||
start_frames = []
|
||
end_frames = []
|
||
|
||
for subtask in annotation.subtasks:
|
||
subtask_names.append(subtask.name)
|
||
|
||
# Convert timestamps to seconds
|
||
start_sec = timestamp_to_seconds(subtask.timestamps.start)
|
||
end_sec = timestamp_to_seconds(subtask.timestamps.end)
|
||
start_times.append(start_sec)
|
||
end_times.append(end_sec)
|
||
|
||
# Calculate frame indices from timestamps and FPS
|
||
start_frame = int(start_sec * fps)
|
||
end_frame = int(end_sec * fps)
|
||
start_frames.append(start_frame)
|
||
end_frames.append(end_frame)
|
||
|
||
# Store as lists in the dataframe
|
||
episodes_df.at[ep_idx, "subtask_names"] = subtask_names
|
||
episodes_df.at[ep_idx, "subtask_start_times"] = start_times
|
||
episodes_df.at[ep_idx, "subtask_end_times"] = end_times
|
||
episodes_df.at[ep_idx, "subtask_start_frames"] = start_frames
|
||
episodes_df.at[ep_idx, "subtask_end_frames"] = end_frames
|
||
|
||
# Group episodes by their chunk and file indices
|
||
episodes_by_file = {}
|
||
for ep_idx in episodes_df.index:
|
||
chunk_idx = episodes_df.loc[ep_idx, "meta/episodes/chunk_index"]
|
||
file_idx = episodes_df.loc[ep_idx, "meta/episodes/file_index"]
|
||
key = (chunk_idx, file_idx)
|
||
|
||
if key not in episodes_by_file:
|
||
episodes_by_file[key] = []
|
||
episodes_by_file[key].append(ep_idx)
|
||
|
||
# Write back to parquet files
|
||
for (chunk_idx, file_idx), ep_indices in episodes_by_file.items():
|
||
episodes_path = dataset_path / DEFAULT_EPISODES_PATH.format(
|
||
chunk_index=chunk_idx, file_index=file_idx
|
||
)
|
||
|
||
if not episodes_path.exists():
|
||
console.print(f"[yellow]Warning: Episodes file not found: {episodes_path}[/yellow]")
|
||
continue
|
||
|
||
# Read the existing parquet file
|
||
file_df = pd.read_parquet(episodes_path)
|
||
|
||
# Add subtask columns if they don't exist
|
||
for col in ["subtask_names", "subtask_start_times", "subtask_end_times",
|
||
"subtask_start_frames", "subtask_end_frames"]:
|
||
if col not in file_df.columns:
|
||
file_df[col] = None
|
||
|
||
# Update rows that have annotations
|
||
for ep_idx in ep_indices:
|
||
if ep_idx in file_df.index and ep_idx in annotations:
|
||
file_df.at[ep_idx, "subtask_names"] = episodes_df.loc[ep_idx, "subtask_names"]
|
||
file_df.at[ep_idx, "subtask_start_times"] = episodes_df.loc[ep_idx, "subtask_start_times"]
|
||
file_df.at[ep_idx, "subtask_end_times"] = episodes_df.loc[ep_idx, "subtask_end_times"]
|
||
file_df.at[ep_idx, "subtask_start_frames"] = episodes_df.loc[ep_idx, "subtask_start_frames"]
|
||
file_df.at[ep_idx, "subtask_end_frames"] = episodes_df.loc[ep_idx, "subtask_end_frames"]
|
||
|
||
# Write back to parquet
|
||
file_df.to_parquet(episodes_path, engine="pyarrow", compression="snappy")
|
||
console.print(f"[green]✓ Updated {episodes_path.name} with {len([e for e in ep_indices if e in annotations])} annotations[/green]")
|
||
|
||
console.print(f"[bold green]✓ Saved {len(annotations)} episode annotations to parquet files[/bold green]")
|
||
|
||
|
||
def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
|
||
"""
|
||
Load annotations from LeRobot dataset parquet files.
|
||
|
||
Reads subtask annotations from the episodes metadata parquet files.
|
||
"""
|
||
from lerobot.datasets.utils import load_episodes
|
||
|
||
episodes_dataset = load_episodes(dataset_path)
|
||
|
||
if episodes_dataset is None or len(episodes_dataset) == 0:
|
||
return {}
|
||
|
||
# Check if subtask columns exist
|
||
if "subtask_names" not in episodes_dataset.column_names:
|
||
return {}
|
||
|
||
# Convert to pandas DataFrame for easier access
|
||
episodes_df = episodes_dataset.to_pandas()
|
||
|
||
annotations = {}
|
||
|
||
for ep_idx in episodes_df.index:
|
||
subtask_names = episodes_df.loc[ep_idx, "subtask_names"]
|
||
|
||
# Skip episodes without annotations
|
||
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
|
||
continue
|
||
|
||
start_times = episodes_df.loc[ep_idx, "subtask_start_times"]
|
||
end_times = episodes_df.loc[ep_idx, "subtask_end_times"]
|
||
|
||
# Reconstruct SubtaskAnnotation from stored data
|
||
subtasks = []
|
||
for i, name in enumerate(subtask_names):
|
||
# Convert seconds back to MM:SS format
|
||
start_sec = int(start_times[i])
|
||
end_sec = int(end_times[i])
|
||
start_str = f"{start_sec // 60:02d}:{start_sec % 60:02d}"
|
||
end_str = f"{end_sec // 60:02d}:{end_sec % 60:02d}"
|
||
|
||
subtasks.append(
|
||
Subtask(
|
||
name=name,
|
||
timestamps=Timestamp(start=start_str, end=end_str)
|
||
)
|
||
)
|
||
|
||
annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)
|
||
|
||
return annotations
|
||
|
||
|
||
def process_single_episode(
|
||
ep_idx: int,
|
||
dataset_root: Path,
|
||
dataset_meta,
|
||
video_key: str,
|
||
fps: int,
|
||
annotator: VideoAnnotator,
|
||
console: Console,
|
||
) -> tuple[int, SubtaskAnnotation | None, str | None]:
|
||
"""
|
||
Process a single episode annotation.
|
||
|
||
Args:
|
||
ep_idx: Episode index
|
||
dataset_root: Dataset root path
|
||
dataset_meta: Dataset metadata
|
||
video_key: Video key to use
|
||
fps: FPS of the video
|
||
annotator: VideoAnnotator instance
|
||
console: Console for output
|
||
|
||
Returns:
|
||
Tuple of (episode_index, annotation or None, error message or None)
|
||
"""
|
||
try:
|
||
# Get video path
|
||
video_path = dataset_root / dataset_meta.get_video_file_path(ep_idx, video_key)
|
||
|
||
if not video_path.exists():
|
||
return ep_idx, None, f"Video not found: {video_path}"
|
||
|
||
# Get video-specific timestamps (NOT global frame indices)
|
||
video_path_key = f"videos/{video_key}/from_timestamp"
|
||
video_path_key_to = f"videos/{video_key}/to_timestamp"
|
||
|
||
start_timestamp = float(dataset_meta.episodes[video_path_key][ep_idx])
|
||
end_timestamp = float(dataset_meta.episodes[video_path_key_to][ep_idx])
|
||
|
||
# Annotate with video-specific timestamps
|
||
annotation = annotator.annotate(
|
||
video_path,
|
||
fps,
|
||
start_timestamp=start_timestamp,
|
||
end_timestamp=end_timestamp
|
||
)
|
||
|
||
return ep_idx, annotation, None
|
||
|
||
except Exception as e:
|
||
return ep_idx, None, str(e)
|
||
|
||
|
||
def worker_process_episodes(
|
||
worker_id: int,
|
||
gpu_id: int,
|
||
episode_indices: list[int],
|
||
repo_id: str,
|
||
video_key: str,
|
||
subtask_list: list[str],
|
||
model_name: str,
|
||
torch_dtype: torch.dtype,
|
||
) -> dict[int, SubtaskAnnotation]:
|
||
"""
|
||
Worker function for parallel processing across GPUs.
|
||
|
||
Args:
|
||
worker_id: Worker ID for logging
|
||
gpu_id: GPU device ID to use
|
||
episode_indices: List of episode indices to process
|
||
repo_id: Dataset repo ID
|
||
video_key: Video key to use
|
||
subtask_list: List of subtask names
|
||
model_name: Model name to load
|
||
torch_dtype: Model dtype
|
||
|
||
Returns:
|
||
Dictionary of episode_idx -> SubtaskAnnotation
|
||
"""
|
||
# Set GPU device
|
||
device = f"cuda:{gpu_id}"
|
||
|
||
# Initialize console for this worker
|
||
console = Console()
|
||
console.print(f"[cyan]Worker {worker_id} starting on GPU {gpu_id} with {len(episode_indices)} episodes[/cyan]")
|
||
|
||
# Load dataset (this is lightweight, just metadata)
|
||
dataset = LeRobotDataset(repo_id, download_videos=False)
|
||
fps = dataset.fps
|
||
|
||
# Initialize annotator for this worker
|
||
annotator = VideoAnnotator(
|
||
subtask_list=subtask_list,
|
||
model_name=model_name,
|
||
device=device,
|
||
torch_dtype=torch_dtype
|
||
)
|
||
|
||
# Process assigned episodes
|
||
annotations = {}
|
||
|
||
for i, ep_idx in enumerate(episode_indices):
|
||
console.print(f"[cyan]Worker {worker_id} | Episode {ep_idx} ({i+1}/{len(episode_indices)})[/cyan]")
|
||
|
||
result_ep_idx, annotation, error = process_single_episode(
|
||
ep_idx,
|
||
dataset.root,
|
||
dataset.meta,
|
||
video_key,
|
||
fps,
|
||
annotator,
|
||
console
|
||
)
|
||
|
||
if error:
|
||
console.print(f"[red]Worker {worker_id} | ✗ Failed episode {result_ep_idx}: {error}[/red]")
|
||
elif annotation:
|
||
annotations[result_ep_idx] = annotation
|
||
console.print(f"[green]Worker {worker_id} | ✓ Completed episode {result_ep_idx}[/green]")
|
||
|
||
console.print(f"[bold green]Worker {worker_id} completed {len(annotations)}/{len(episode_indices)} episodes[/bold green]")
|
||
return annotations
|
||
|
||
|
||
def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
|
||
"""
|
||
Compute average temporal proportion for each subtask across all episodes.
|
||
This is the key insight from SARM - use semantic subtasks instead of frame indices.
|
||
"""
|
||
# Collect all proportions per subtask
|
||
subtask_proportions = {}
|
||
|
||
for annotation in annotations.values():
|
||
# Calculate total episode duration
|
||
total_duration = 0
|
||
durations = {}
|
||
|
||
for subtask in annotation.subtasks:
|
||
# Parse timestamps
|
||
start_parts = subtask.timestamps.start.split(":")
|
||
end_parts = subtask.timestamps.end.split(":")
|
||
|
||
if len(start_parts) == 2:
|
||
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1])
|
||
else:
|
||
start_seconds = int(start_parts[0])
|
||
|
||
if len(end_parts) == 2:
|
||
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1])
|
||
else:
|
||
end_seconds = int(end_parts[0])
|
||
|
||
duration = end_seconds - start_seconds
|
||
durations[subtask.name] = duration
|
||
total_duration += duration
|
||
|
||
# Calculate proportions for this episode
|
||
if total_duration > 0:
|
||
for name, duration in durations.items():
|
||
if name not in subtask_proportions:
|
||
subtask_proportions[name] = []
|
||
subtask_proportions[name].append(duration / total_duration)
|
||
|
||
# Average across episodes
|
||
avg_proportions = {
|
||
name: sum(props) / len(props)
|
||
for name, props in subtask_proportions.items()
|
||
}
|
||
|
||
# Normalize to sum to 1.0
|
||
total = sum(avg_proportions.values())
|
||
if total > 0:
|
||
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
|
||
|
||
return avg_proportions
|
||
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(
|
||
description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
|
||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||
epilog="""
|
||
Examples:
|
||
# Sequential processing (single GPU):
|
||
python subtask_annotation.py \\
|
||
--repo-id pepijn223/mydataset \\
|
||
--subtasks "reach,grasp,lift,place" \\
|
||
--video-key observation.images.top \\
|
||
--push-to-hub
|
||
|
||
# Parallel processing with 4 GPUs (4x speedup):
|
||
python subtask_annotation.py \\
|
||
--repo-id pepijn223/mydataset \\
|
||
--subtasks "reach,grasp,lift,place" \\
|
||
--video-key observation.images.top \\
|
||
--num-workers 4 \\
|
||
--push-to-hub
|
||
|
||
# Parallel with specific GPU IDs (e.g., GPUs 0, 2, 3):
|
||
python subtask_annotation.py \\
|
||
--repo-id pepijn223/mydataset \\
|
||
--subtasks "reach,grasp,lift,place" \\
|
||
--video-key observation.images.top \\
|
||
--num-workers 3 \\
|
||
--gpu-ids 0 2 3 \\
|
||
--push-to-hub
|
||
|
||
# List available cameras:
|
||
python subtask_annotation.py \\
|
||
--repo-id pepijn223/mydataset \\
|
||
--subtasks "reach,grasp" \\
|
||
--max-episodes 0
|
||
|
||
Performance Tips:
|
||
- Each worker loads one model instance on its assigned GPU
|
||
- The 30B model requires ~60GB VRAM per GPU
|
||
- Use --num-workers N for N GPUs to get N× speedup
|
||
- Episodes are distributed round-robin across workers
|
||
"""
|
||
)
|
||
parser.add_argument(
|
||
"--repo-id",
|
||
type=str,
|
||
required=True,
|
||
help="HuggingFace dataset repository ID (e.g., 'pepijn223/mydataset')",
|
||
)
|
||
parser.add_argument(
|
||
"--subtasks",
|
||
type=str,
|
||
required=True,
|
||
help="Comma-separated list of subtask names (e.g., 'reach,grasp,lift,place')",
|
||
)
|
||
parser.add_argument(
|
||
"--episodes",
|
||
type=int,
|
||
nargs="+",
|
||
default=None,
|
||
help="Specific episode indices to annotate (default: all episodes)",
|
||
)
|
||
parser.add_argument(
|
||
"--max-episodes",
|
||
type=int,
|
||
default=None,
|
||
help="Maximum number of episodes to annotate",
|
||
)
|
||
parser.add_argument(
|
||
"--model",
|
||
type=str,
|
||
default="Qwen/Qwen3-VL-30B-A3B-Instruct",
|
||
help="Qwen3-VL model to use (default: Qwen/Qwen3-VL-30B-A3B-Instruct)",
|
||
)
|
||
parser.add_argument(
|
||
"--skip-existing",
|
||
action="store_true",
|
||
help="Skip episodes that already have annotations",
|
||
)
|
||
parser.add_argument(
|
||
"--video-key",
|
||
type=str,
|
||
default=None,
|
||
help="Camera/video key to use for annotation (e.g., 'observation.images.top'). "
|
||
"If not specified, uses the first available video key.",
|
||
)
|
||
parser.add_argument(
|
||
"--push-to-hub",
|
||
action="store_true",
|
||
help="Push annotated dataset to HuggingFace Hub",
|
||
)
|
||
parser.add_argument(
|
||
"--output-repo-id",
|
||
type=str,
|
||
default=None,
|
||
help="Output repository ID for push (default: same as --repo-id)",
|
||
)
|
||
parser.add_argument(
|
||
"--device",
|
||
type=str,
|
||
default="cuda",
|
||
help="Device to use (cuda, cpu)",
|
||
)
|
||
parser.add_argument(
|
||
"--dtype",
|
||
type=str,
|
||
default="bfloat16",
|
||
choices=["bfloat16", "float16", "float32"],
|
||
help="Model dtype (default: bfloat16)",
|
||
)
|
||
parser.add_argument(
|
||
"--num-workers",
|
||
type=int,
|
||
default=1,
|
||
help="Number of parallel workers for multi-GPU processing (default: 1 for sequential). "
|
||
"Set to number of GPUs available for parallel processing.",
|
||
)
|
||
parser.add_argument(
|
||
"--gpu-ids",
|
||
type=int,
|
||
nargs="+",
|
||
default=None,
|
||
help="Specific GPU IDs to use (e.g., --gpu-ids 0 1 2). "
|
||
"If not specified, uses GPUs 0 to num-workers-1.",
|
||
)
|
||
parser.add_argument(
|
||
"--batch-size",
|
||
type=int,
|
||
default=1,
|
||
help="Batch size for processing multiple episodes per inference (experimental, default: 1)",
|
||
)
|
||
|
||
args = parser.parse_args()
|
||
|
||
# Parse subtask list
|
||
subtask_list = [s.strip() for s in args.subtasks.split(",")]
|
||
|
||
# Parse dtype
|
||
dtype_map = {
|
||
"bfloat16": torch.bfloat16,
|
||
"float16": torch.float16,
|
||
"float32": torch.float32,
|
||
}
|
||
torch_dtype = dtype_map[args.dtype]
|
||
|
||
console = Console()
|
||
console.print(Panel.fit(
|
||
"[bold cyan]SARM Subtask Annotation (Local GPU)[/bold cyan]\n"
|
||
f"Dataset: {args.repo_id}\n"
|
||
f"Model: {args.model}\n"
|
||
f"Device: {args.device}\n"
|
||
f"Subtasks: {', '.join(subtask_list)}",
|
||
border_style="cyan"
|
||
))
|
||
|
||
# Load dataset
|
||
console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
|
||
dataset = LeRobotDataset(args.repo_id, download_videos=True)
|
||
|
||
# Get FPS from dataset
|
||
fps = dataset.fps
|
||
console.print(f"[cyan]Dataset FPS: {fps}[/cyan]")
|
||
|
||
# Display available cameras/video keys
|
||
if len(dataset.meta.video_keys) == 0:
|
||
console.print("[red]Error: No video keys found in dataset[/red]")
|
||
return
|
||
|
||
console.print(f"\n[cyan]Available cameras/video keys:[/cyan]")
|
||
for i, vk in enumerate(dataset.meta.video_keys, 1):
|
||
console.print(f" {i}. {vk}")
|
||
|
||
# Get video key
|
||
if args.video_key:
|
||
if args.video_key not in dataset.meta.video_keys:
|
||
console.print(f"[red]Error: Video key '{args.video_key}' not found in dataset[/red]")
|
||
console.print(f"[yellow]Available keys: {', '.join(dataset.meta.video_keys)}[/yellow]")
|
||
return
|
||
video_key = args.video_key
|
||
console.print(f"\n[green]Using specified camera: {video_key}[/green]")
|
||
else:
|
||
video_key = dataset.meta.video_keys[0]
|
||
console.print(f"\n[yellow]No camera specified, using first available: {video_key}[/yellow]")
|
||
if len(dataset.meta.video_keys) > 1:
|
||
console.print(f"[yellow]Tip: Use --video-key to specify a different camera[/yellow]")
|
||
|
||
# Determine episodes to annotate
|
||
if args.episodes:
|
||
episode_indices = args.episodes
|
||
else:
|
||
episode_indices = list(range(dataset.meta.total_episodes))
|
||
if args.max_episodes:
|
||
episode_indices = episode_indices[: args.max_episodes]
|
||
|
||
console.print(f"[cyan]Will annotate {len(episode_indices)} episodes[/cyan]")
|
||
|
||
# Load existing annotations
|
||
existing_annotations = load_annotations_from_dataset(dataset.root)
|
||
|
||
if args.skip_existing and existing_annotations:
|
||
console.print(f"[yellow]Found {len(existing_annotations)} existing annotations[/yellow]")
|
||
episode_indices = [ep for ep in episode_indices if ep not in existing_annotations]
|
||
console.print(f"[cyan]Will annotate {len(episode_indices)} new episodes[/cyan]")
|
||
|
||
if not episode_indices:
|
||
console.print("[green]All episodes already annotated![/green]")
|
||
return
|
||
|
||
# Determine GPU IDs to use
|
||
if args.gpu_ids:
|
||
gpu_ids = args.gpu_ids
|
||
if len(gpu_ids) < args.num_workers:
|
||
console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {len(gpu_ids)} GPU IDs provided[/yellow]")
|
||
args.num_workers = len(gpu_ids)
|
||
else:
|
||
# Check available GPUs
|
||
if torch.cuda.is_available():
|
||
num_gpus = torch.cuda.device_count()
|
||
if args.num_workers > num_gpus:
|
||
console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {num_gpus} GPUs available[/yellow]")
|
||
args.num_workers = min(args.num_workers, num_gpus)
|
||
gpu_ids = list(range(args.num_workers))
|
||
else:
|
||
console.print("[yellow]Warning: CUDA not available, using CPU (num_workers will be ignored)[/yellow]")
|
||
args.num_workers = 1
|
||
gpu_ids = [0] # Dummy value for CPU
|
||
|
||
# Annotate episodes - choose sequential or parallel mode
|
||
annotations = existing_annotations.copy()
|
||
|
||
if args.num_workers > 1:
|
||
# ===== PARALLEL PROCESSING MODE =====
|
||
console.print(f"\n[bold cyan]Using {args.num_workers} parallel workers on GPUs: {gpu_ids}[/bold cyan]")
|
||
|
||
# Split episodes across workers
|
||
episodes_per_worker = [[] for _ in range(args.num_workers)]
|
||
for i, ep_idx in enumerate(episode_indices):
|
||
worker_idx = i % args.num_workers
|
||
episodes_per_worker[worker_idx].append(ep_idx)
|
||
|
||
# Show distribution
|
||
for worker_id, episodes in enumerate(episodes_per_worker):
|
||
console.print(f"[cyan]Worker {worker_id} (GPU {gpu_ids[worker_id]}): {len(episodes)} episodes[/cyan]")
|
||
|
||
# Start parallel processing using ProcessPoolExecutor
|
||
console.print(f"\n[bold cyan]Starting parallel annotation...[/bold cyan]")
|
||
|
||
# Use 'spawn' method for CUDA compatibility (required for multi-GPU)
|
||
mp_context = mp.get_context('spawn')
|
||
|
||
with ProcessPoolExecutor(max_workers=args.num_workers, mp_context=mp_context) as executor:
|
||
# Submit all worker jobs
|
||
futures = []
|
||
for worker_id in range(args.num_workers):
|
||
if not episodes_per_worker[worker_id]:
|
||
continue # Skip workers with no episodes
|
||
|
||
future = executor.submit(
|
||
worker_process_episodes,
|
||
worker_id,
|
||
gpu_ids[worker_id],
|
||
episodes_per_worker[worker_id],
|
||
args.repo_id,
|
||
video_key,
|
||
subtask_list,
|
||
args.model,
|
||
torch_dtype,
|
||
)
|
||
futures.append(future)
|
||
|
||
# Collect results as they complete
|
||
for future in as_completed(futures):
|
||
try:
|
||
worker_annotations = future.result()
|
||
annotations.update(worker_annotations)
|
||
|
||
# Save after each worker completes
|
||
save_annotations_to_dataset(dataset.root, annotations, fps)
|
||
console.print(f"[green]✓ Worker completed, saved {len(worker_annotations)} annotations[/green]")
|
||
|
||
except Exception as e:
|
||
console.print(f"[red]✗ Worker failed: {e}[/red]")
|
||
|
||
console.print(f"\n[bold green]Parallel processing complete! Annotated {len(annotations)} episodes[/bold green]")
|
||
|
||
# Display all annotations
|
||
for ep_idx in sorted(annotations.keys()):
|
||
if ep_idx not in existing_annotations: # Only show newly annotated
|
||
display_annotation(annotations[ep_idx], console, ep_idx, fps)
|
||
|
||
else:
|
||
# ===== SEQUENTIAL PROCESSING MODE =====
|
||
console.print(f"\n[bold cyan]Using sequential processing (single GPU/CPU)[/bold cyan]")
|
||
|
||
# Initialize annotator with subtask list
|
||
annotator = VideoAnnotator(
|
||
subtask_list=subtask_list,
|
||
model_name=args.model,
|
||
device=args.device,
|
||
torch_dtype=torch_dtype
|
||
)
|
||
|
||
# Process episodes sequentially
|
||
for i, ep_idx in enumerate(episode_indices):
|
||
console.print(f"\n[bold cyan]{'=' * 60}[/bold cyan]")
|
||
console.print(f"[bold cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/bold cyan]")
|
||
console.print(f"[bold cyan]{'=' * 60}[/bold cyan]")
|
||
|
||
result_ep_idx, annotation, error = process_single_episode(
|
||
ep_idx,
|
||
dataset.root,
|
||
dataset.meta,
|
||
video_key,
|
||
fps,
|
||
annotator,
|
||
console
|
||
)
|
||
|
||
if error:
|
||
console.print(f"[red]✗ Failed to annotate episode {result_ep_idx}: {error}[/red]")
|
||
continue
|
||
elif annotation:
|
||
annotations[result_ep_idx] = annotation
|
||
display_annotation(annotation, console, result_ep_idx, fps)
|
||
save_annotations_to_dataset(dataset.root, annotations, fps)
|
||
|
||
# Compute temporal proportions (key SARM insight)
|
||
console.print(f"\n[bold cyan]Computing Temporal Proportions[/bold cyan]")
|
||
temporal_proportions = compute_temporal_proportions(annotations, fps)
|
||
|
||
# Save temporal proportions
|
||
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
|
||
proportions_path.parent.mkdir(parents=True, exist_ok=True)
|
||
with open(proportions_path, "w") as f:
|
||
json.dump(temporal_proportions, f, indent=2)
|
||
|
||
console.print(f"[green]✓ Saved temporal proportions to {proportions_path}[/green]")
|
||
console.print("\n[cyan]Average temporal proportions:[/cyan]")
|
||
for name, proportion in sorted(temporal_proportions.items(), key=lambda x: -x[1]):
|
||
console.print(f" {name}: {proportion:.1%}")
|
||
|
||
# Create summary
|
||
console.print(f"\n[bold green]{'=' * 60}[/bold green]")
|
||
console.print(f"[bold green]Annotation Complete![/bold green]")
|
||
console.print(f"[bold green]{'=' * 60}[/bold green]")
|
||
console.print(f"Total episodes annotated: {len(annotations)}")
|
||
console.print(f"Total subtasks found: {sum(len(ann.subtasks) for ann in annotations.values())}")
|
||
|
||
# Push to hub if requested
|
||
if args.push_to_hub:
|
||
output_repo = args.output_repo_id if args.output_repo_id else args.repo_id
|
||
console.print(f"\n[bold cyan]Pushing to HuggingFace Hub: {output_repo}[/bold cyan]")
|
||
|
||
try:
|
||
dataset.push_to_hub(push_videos=True)
|
||
console.print(f"[bold green]✓ Successfully pushed to {output_repo}[/bold green]")
|
||
except Exception as e:
|
||
console.print(f"[red]✗ Failed to push to hub: {e}[/red]")
|
||
console.print("[yellow]Annotations are still saved locally[/yellow]")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|
||
|