#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Visualize SARM Subtask Annotations

This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
For each episode, it shows:
- A timeline with dashed vertical lines at subtask boundaries
- One sample frame per subtask, taken at the midpoint of that subtask
- Color-coded subtask segments

Usage:
    python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
"""

import argparse
import random
from pathlib import Path

import cv2
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from rich.console import Console

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import load_episodes
from lerobot.policies.sarm.sarm_utils import Subtask, SubtaskAnnotation, Timestamp

# Episode-metadata columns that together describe the subtask annotations.
_SUBTASK_COLUMNS = ("subtask_names", "subtask_start_times", "subtask_end_times")

# Figure styling shared by the frame row and the timeline.
_FIGURE_BG = "#1a1a2e"
_AXES_BG = "#16213e"


def timestamp_to_seconds(timestamp: str) -> float:
    """Convert a colon-separated timestamp to seconds.

    Accepts ``SS``, ``MM:SS`` and ``HH:MM:SS``; each field may be fractional.

    Args:
        timestamp: Timestamp string such as ``"01:30"`` or ``"45"``.

    Returns:
        The timestamp expressed in seconds.

    Raises:
        ValueError: If any field is not a number.
    """
    seconds = 0.0
    # Fold left-to-right in base 60 so every additional leading field
    # (minutes, hours) is weighted correctly.
    for field in timestamp.strip().split(":"):
        seconds = seconds * 60 + float(field)
    return seconds


def _format_mmss(seconds: float) -> str:
    """Format a non-negative number of seconds as zero-padded ``MM:SS`` (fraction truncated)."""
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"


def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
    """
    Load annotations from LeRobot dataset parquet files.

    Reads the ``subtask_names``, ``subtask_start_times`` and ``subtask_end_times``
    columns of the episodes metadata and rebuilds one ``SubtaskAnnotation`` per
    annotated episode.

    Args:
        dataset_path: Root directory of the dataset, passed to ``load_episodes``.

    Returns:
        Mapping from the episodes-table row index to its annotation. Empty when
        the metadata is missing/empty or any of the subtask columns is absent.

    Raises:
        ValueError: If an episode's names, start times and end times do not all
            have the same length.
    """
    episodes_dataset = load_episodes(dataset_path)

    if episodes_dataset is None or len(episodes_dataset) == 0:
        return {}

    # All three columns are needed to rebuild an annotation.
    if any(column not in episodes_dataset.column_names for column in _SUBTASK_COLUMNS):
        return {}

    # Convert to pandas DataFrame for easier access
    episodes_df = episodes_dataset.to_pandas()

    annotations: dict[int, SubtaskAnnotation] = {}
    columns = (episodes_df[column] for column in _SUBTASK_COLUMNS)

    for ep_idx, subtask_names, start_times, end_times in zip(episodes_df.index, *columns):
        # Skip episodes without annotations (missing values surface as None or NaN).
        if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
            continue

        # Stored times are in seconds; Timestamp expects MM:SS strings.
        subtasks = [
            Subtask(
                name=name,
                timestamps=Timestamp(start=_format_mmss(start), end=_format_mmss(end)),
            )
            for name, start, end in zip(subtask_names, start_times, end_times, strict=True)
        ]

        annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)

    return annotations


# Color palette for subtasks (colorblind-friendly)
SUBTASK_COLORS = [
    "#E69F00",  # Orange
    "#56B4E9",  # Sky blue
    "#009E73",  # Bluish green
    "#F0E442",  # Yellow
    "#0072B2",  # Blue
    "#D55E00",  # Vermillion
    "#CC79A7",  # Reddish purple
    "#999999",  # Gray
]

# Palette entries too light for white label text; labels on these use black.
_LIGHT_SUBTASK_COLORS = frozenset({"#F0E442"})


def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
    """Extract a single frame from video at given timestamp.

    Args:
        video_path: Path of the video file to open.
        timestamp: Position in the video file, in seconds.

    Returns:
        The frame converted from BGR to RGB, or ``None`` if the video cannot
        be opened or no frame can be read at that position.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return None

        # CAP_PROP_POS_MSEC is expressed in milliseconds.
        cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
        ret, frame = cap.read()
    finally:
        # Always release the capture handle, even if seeking/reading raises.
        cap.release()

    if not ret:
        return None
    # Convert BGR to RGB
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def _subtask_spans(subtasks) -> list[tuple[float, float]]:
    """Return the ``(start, end)`` of every subtask in seconds, parsed once."""
    return [
        (timestamp_to_seconds(subtask.timestamps.start), timestamp_to_seconds(subtask.timestamps.end))
        for subtask in subtasks
    ]


def _boundary_times(spans: list[tuple[float, float]]) -> list[float]:
    """Return the times at which dashed boundary lines are drawn.

    Every subtask start is a boundary, except a first subtask starting at t=0
    (the solid start line already marks it); the end of the last subtask is
    a boundary as well.
    """
    boundaries = [start for i, (start, _) in enumerate(spans) if i > 0 or start > 0]
    boundaries.append(spans[-1][1])
    return boundaries


def _draw_sample_frames(fig, gs, subtasks, sample_frames, frame_timestamps) -> None:
    """Fill the top grid row with one frame (or a placeholder) per subtask."""
    for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
        ax = fig.add_subplot(gs[0, i])
        ax.set_facecolor(_AXES_BG)

        if frame is not None:
            ax.imshow(frame)
        else:
            ax.text(
                0.5,
                0.5,
                "Frame\nN/A",
                ha="center",
                va="center",
                fontsize=12,
                color="white",
                transform=ax.transAxes,
            )

        ax.set_title(
            f"{subtask.name}",
            fontsize=10,
            fontweight="bold",
            color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
            pad=8,
        )
        ax.axis("off")

        # Add frame timestamp below
        ax.text(
            0.5,
            -0.08,
            f"t={frame_timestamps[i]:.1f}s",
            ha="center",
            fontsize=9,
            color="#888888",
            transform=ax.transAxes,
        )


def _draw_timeline(ax_timeline, subtasks, spans: list[tuple[float, float]], total_duration: float) -> None:
    """Draw colored subtask bars, boundary lines, ticks and legend on the timeline axes."""
    ax_timeline.set_facecolor(_AXES_BG)

    # Draw subtask segments as colored bars
    bar_height = 0.6
    bar_y = 0.5

    for i, (subtask, (start_sec, end_sec)) in enumerate(zip(subtasks, spans)):
        color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]
        duration = end_sec - start_sec

        # Draw segment bar
        rect = mpatches.FancyBboxPatch(
            (start_sec, bar_y - bar_height / 2),
            duration,
            bar_height,
            boxstyle="round,pad=0.02,rounding_size=0.1",
            facecolor=color,
            edgecolor="white",
            linewidth=1.5,
            alpha=0.85,
        )
        ax_timeline.add_patch(rect)

        # Only add text if segment is wide enough
        if duration > total_duration * 0.08:
            ax_timeline.text(
                (start_sec + end_sec) / 2,
                bar_y,
                subtask.name,
                ha="center",
                va="center",
                fontsize=9,
                fontweight="bold",
                # Decide by the actual bar color so the choice stays right
                # when the palette cycles past its length.
                color="black" if color in _LIGHT_SUBTASK_COLORS else "white",
                rotation=0 if duration > total_duration * 0.15 else 45,
            )

    # Draw dashed lines at boundaries
    for t in _boundary_times(spans):
        ax_timeline.axvline(
            x=t,
            ymin=0.1,
            ymax=0.9,
            color="white",
            linestyle="--",
            linewidth=2,
            alpha=0.9,
        )

        # Add time label below line
        ax_timeline.text(
            t,
            0.0,
            _format_mmss(t),
            ha="center",
            va="top",
            fontsize=8,
            color="#cccccc",
        )

    # Add start line at t=0
    ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color="#00ff00", linestyle="-", linewidth=2.5, alpha=0.9)
    ax_timeline.text(0, 0.0, "00:00", ha="center", va="top", fontsize=8, color="#00ff00", fontweight="bold")

    # Configure timeline axes
    ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
    ax_timeline.set_ylim(-0.3, 1.2)
    ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color="white", labelpad=10)
    ax_timeline.set_ylabel("")

    # Style the axes
    ax_timeline.spines["top"].set_visible(False)
    ax_timeline.spines["right"].set_visible(False)
    ax_timeline.spines["left"].set_visible(False)
    ax_timeline.spines["bottom"].set_color("#444444")
    ax_timeline.tick_params(axis="x", colors="#888888", labelsize=9)
    ax_timeline.tick_params(axis="y", left=False, labelleft=False)

    # Add x-axis ticks at regular intervals (roughly ten, at least 1 s apart)
    tick_interval = max(1, int(total_duration / 10))
    ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))

    # Add legend explaining line styles
    legend_elements = [
        Line2D([0], [0], color="#00ff00", linewidth=2.5, linestyle="-", label="Start"),
        Line2D([0], [0], color="white", linewidth=2, linestyle="--", label="Subtask boundary"),
    ]
    ax_timeline.legend(
        handles=legend_elements,
        loc="upper right",
        framealpha=0.3,
        facecolor=_AXES_BG,
        edgecolor="#444444",
        fontsize=9,
        labelcolor="white",
    )


def visualize_episode(
    episode_idx: int,
    annotation: SubtaskAnnotation,
    video_path: Path,
    video_start_timestamp: float,
    video_end_timestamp: float,
    fps: int,
    output_path: Path,
    video_key: str,
) -> Path | None:
    """
    Create visualization for a single episode.

    Shows:
    - Top row: Sample frames from the episode (one per subtask, at its midpoint)
    - Bottom: Timeline with subtask segments and boundary lines

    Args:
        episode_idx: Episode index, used in the figure title.
        annotation: Subtask annotation of the episode.
        video_path: Video file containing the episode.
        video_start_timestamp: Offset of the episode inside the video file, in seconds.
        video_end_timestamp: End of the episode inside the video file, in seconds.
        fps: Dataset frame rate. Currently unused; kept for interface compatibility.
        output_path: Destination of the saved figure.
        video_key: Camera key, shown in the subtitle.

    Returns:
        ``output_path`` once the figure is saved, or ``None`` when the
        annotation contains no subtasks.
    """
    subtasks = annotation.subtasks
    num_subtasks = len(subtasks)

    if num_subtasks == 0:
        print(f"No subtasks found for episode {episode_idx}")
        return None

    # Calculate episode duration
    episode_duration = video_end_timestamp - video_start_timestamp

    spans = _subtask_spans(subtasks)

    # Sample one frame at the middle of each subtask. Annotation times are
    # relative to the episode, so shift by the episode's offset in the file.
    frame_timestamps = [(start_sec + end_sec) / 2 for start_sec, end_sec in spans]
    sample_frames = [
        extract_frame_from_video(video_path, video_start_timestamp + mid_sec) for mid_sec in frame_timestamps
    ]

    # Timeline extent: latest annotated end time. Fall back to the episode
    # duration (or 1 s) so the axis limits never collapse to a zero width.
    total_duration = max(end_sec for _, end_sec in spans)
    if total_duration <= 0:
        total_duration = episode_duration if episode_duration > 0 else 1.0

    fig = plt.figure(figsize=(16, 10))
    try:
        # Use a dark background for better contrast
        fig.patch.set_facecolor(_FIGURE_BG)

        # Top row: one column per subtask frame; bottom row: timeline.
        gs = fig.add_gridspec(
            2,
            max(num_subtasks, 1),
            height_ratios=[2, 1],
            hspace=0.3,
            wspace=0.1,
            left=0.05,
            right=0.95,
            top=0.88,
            bottom=0.1,
        )

        # Add title
        fig.suptitle(
            f"Episode {episode_idx} - Subtask Annotations",
            fontsize=18,
            fontweight="bold",
            color="white",
            y=0.96,
        )

        # Add subtitle with video info
        fig.text(
            0.5,
            0.91,
            f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
            ha="center",
            fontsize=11,
            color="#888888",
        )

        _draw_sample_frames(fig, gs, subtasks, sample_frames, frame_timestamps)

        # Create timeline subplot spanning all columns
        ax_timeline = fig.add_subplot(gs[1, :])
        _draw_timeline(ax_timeline, subtasks, spans, total_duration)

        # Save figure
        fig.savefig(
            output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor="none", bbox_inches="tight"
        )
    finally:
        # Close this specific figure even on failure so repeated calls do not
        # accumulate open figures.
        plt.close(fig)

    return output_path


def main():
    """Parse CLI arguments and render one annotation figure per selected episode."""
    parser = argparse.ArgumentParser(
        description="Visualize SARM subtask annotations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID",
    )
    parser.add_argument(
        "--num-episodes",
        type=int,
        default=5,
        help="Number of random episodes to visualize (default: 5)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        default=None,
        help="Specific episode indices to visualize (overrides --num-episodes)",
    )
    parser.add_argument(
        "--video-key",
        type=str,
        default=None,
        help="Camera/video key to use. If not specified, uses first available.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./subtask_viz",
        help="Output directory for visualizations (default: ./subtask_viz)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )

    args = parser.parse_args()

    console = Console()

    # Set random seed if specified
    if args.seed is not None:
        random.seed(args.seed)

    console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
    dataset = LeRobotDataset(args.repo_id, download_videos=True)
    fps = dataset.fps

    # Get video key
    video_keys = dataset.meta.video_keys
    if args.video_key:
        if args.video_key not in video_keys:
            console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
            console.print(f"[yellow]Available: {', '.join(video_keys)}[/yellow]")
            return
        video_key = args.video_key
    elif video_keys:
        video_key = video_keys[0]
    else:
        console.print("[red]Error: Dataset has no video keys[/red]")
        return

    console.print(f"[cyan]Using camera: {video_key}[/cyan]")
    console.print(f"[cyan]FPS: {fps}[/cyan]")

    # Load annotations
    console.print("\n[cyan]Loading annotations...[/cyan]")
    annotations = load_annotations_from_dataset(dataset.root)

    if not annotations:
        console.print("[red]Error: No annotations found in dataset[/red]")
        console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
        return

    console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")

    # Determine which episodes to visualize
    if args.episodes is not None:
        # Drop duplicates while keeping the order given on the command line.
        requested = list(dict.fromkeys(args.episodes))
        for ep in requested:
            if ep not in annotations:
                console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
        episode_indices = [ep for ep in requested if ep in annotations]
    else:
        # Random selection; clamp so a negative --num-episodes cannot make
        # random.sample raise.
        available_episodes = list(annotations)
        num_to_select = max(0, min(args.num_episodes, len(available_episodes)))
        episode_indices = sorted(random.sample(available_episodes, num_to_select))

    if not episode_indices:
        console.print("[red]Error: No valid episodes to visualize[/red]")
        return

    console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Episode-specific timestamps within the (possibly shared) video file
    from_key = f"videos/{video_key}/from_timestamp"
    to_key = f"videos/{video_key}/to_timestamp"

    # Generate visualizations
    num_saved = 0
    for ep_idx in episode_indices:
        console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")

        video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)

        if not video_path.exists():
            console.print(f"[red]Video not found: {video_path}[/red]")
            continue

        output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"

        # Top-level boundary: report the failure and move on to the next
        # episode instead of aborting the whole run.
        try:
            saved_path = visualize_episode(
                episode_idx=ep_idx,
                annotation=annotations[ep_idx],
                video_path=video_path,
                video_start_timestamp=float(dataset.meta.episodes[from_key][ep_idx]),
                video_end_timestamp=float(dataset.meta.episodes[to_key][ep_idx]),
                fps=fps,
                output_path=output_path,
                video_key=video_key,
            )
        except Exception as e:
            console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")
            continue

        # visualize_episode returns None when there was nothing to draw.
        if saved_path is not None:
            num_saved += 1
            console.print(f"[green]✓ Saved: {output_path}[/green]")

    # Print summary
    console.print(f"\n[bold green]{'=' * 50}[/bold green]")
    console.print("[bold green]Visualization Complete![/bold green]")
    console.print(f"[bold green]{'=' * 50}[/bold green]")
    console.print(f"Output directory: {output_dir.absolute()}")
    console.print(f"Episodes visualized: {num_saved}")
    num_not_saved = len(episode_indices) - num_saved
    if num_not_saved:
        console.print(f"[yellow]Episodes skipped or failed: {num_not_saved}[/yellow]")


if __name__ == "__main__":
    main()