add visualize subtask annotations

Pepijn
2025-11-28 16:59:29 +01:00
parent fa5004bd8c
commit 6e3b972534
@@ -0,0 +1,525 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visualize SARM Subtask Annotations
This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
For each episode, it shows:
- A timeline with dashed vertical lines at subtask boundaries
- Sample frames from the episode (one from the midpoint of each subtask)
- Color-coded subtask segments
Usage:
python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
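Other invocations (flags as defined in main() below):
python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --episodes 0 3 7
python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --num-episodes 3 --seed 42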
"""
import argparse
import random
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from rich.console import Console
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import load_episodes
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
def timestamp_to_seconds(timestamp: str) -> float:
"""Convert MM:SS or SS timestamp to seconds"""
parts = timestamp.split(":")
if len(parts) == 2:
return int(parts[0]) * 60 + int(parts[1])
else:
return int(parts[0])
def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
"""
Load annotations from LeRobot dataset parquet files.
Reads subtask annotations from the episodes metadata parquet files.
"""
episodes_dataset = load_episodes(dataset_path)
if episodes_dataset is None or len(episodes_dataset) == 0:
return {}
# Check if subtask columns exist
if "subtask_names" not in episodes_dataset.column_names:
return {}
# Convert to pandas DataFrame for easier access
episodes_df = episodes_dataset.to_pandas()
annotations = {}
for ep_idx in episodes_df.index:
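# NOTE: assumes the row order of to_pandas() matches episode_index; if the
# metadata carries an explicit "episode_index" column, prefer that instead.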
subtask_names = episodes_df.loc[ep_idx, "subtask_names"]
# Skip episodes without annotations
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
continue
start_times = episodes_df.loc[ep_idx, "subtask_start_times"]
end_times = episodes_df.loc[ep_idx, "subtask_end_times"]
# Reconstruct SubtaskAnnotation from stored data
subtasks = []
for i, name in enumerate(subtask_names):
# Convert seconds back to MM:SS format
start_sec = int(start_times[i])
end_sec = int(end_times[i])
start_str = f"{start_sec // 60:02d}:{start_sec % 60:02d}"
end_str = f"{end_sec // 60:02d}:{end_sec % 60:02d}"
subtasks.append(
Subtask(
name=name,
timestamps=Timestamp(start=start_str, end=end_str)
)
)
annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)
return annotations
# Color palette for subtasks (Okabe-Ito colorblind-safe palette, with gray in place of black)
SUBTASK_COLORS = [
"#E69F00", # Orange
"#56B4E9", # Sky blue
"#009E73", # Bluish green
"#F0E442", # Yellow
"#0072B2", # Blue
"#D55E00", # Vermillion
"#CC79A7", # Reddish purple
"#999999", # Gray
]
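# Episodes with more than len(SUBTASK_COLORS) subtasks reuse colors via modulo indexing below.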
def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
"""Extract a single frame from video at given timestamp."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
return None
# Set position to timestamp
cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
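# NOTE: seeking by milliseconds lands on the nearest decodable frame; exact
# accuracy varies by codec/backend, which is acceptable for sampling frames here.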
ret, frame = cap.read()
cap.release()
if ret:
# Convert BGR to RGB
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return None
def visualize_episode(
episode_idx: int,
annotation: SubtaskAnnotation,
video_path: Path,
video_start_timestamp: float,
video_end_timestamp: float,
output_path: Path,
video_key: str,
):
"""
Create visualization for a single episode.
Shows:
- Top row: Sample frames from the episode (one per subtask)
- Bottom: Timeline with subtask segments and boundary lines
"""
subtasks = annotation.subtasks
num_subtasks = len(subtasks)
if num_subtasks == 0:
print(f"No subtasks found for episode {episode_idx}")
return
# Calculate episode duration
episode_duration = video_end_timestamp - video_start_timestamp
# Extract sample frames - get frame from middle of each subtask
sample_frames = []
frame_timestamps = []
for subtask in subtasks:
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
mid_sec = (start_sec + end_sec) / 2
# Convert to video timestamp (add video_start_timestamp offset)
video_timestamp = video_start_timestamp + mid_sec
frame_timestamps.append(mid_sec)
frame = extract_frame_from_video(video_path, video_timestamp)
sample_frames.append(frame)
# Create figure
fig = plt.figure(figsize=(16, 10))
# Use a dark background for better contrast
fig.patch.set_facecolor('#1a1a2e')
# Calculate grid layout
# Top section: frames (variable number of columns based on subtasks)
# Bottom section: timeline
# Create gridspec
gs = fig.add_gridspec(
2, max(num_subtasks, 1),
height_ratios=[2, 1],
hspace=0.3,
wspace=0.1,
left=0.05, right=0.95,
top=0.88, bottom=0.1
)
# Add title
fig.suptitle(
f"Episode {episode_idx} - Subtask Annotations",
fontsize=18,
fontweight='bold',
color='white',
y=0.96
)
# Add subtitle with video info
fig.text(
0.5, 0.91,
f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
ha='center',
fontsize=11,
color='#888888'
)
# Plot sample frames
for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
ax = fig.add_subplot(gs[0, i])
ax.set_facecolor('#16213e')
if frame is not None:
ax.imshow(frame)
else:
ax.text(0.5, 0.5, "Frame\nN/A", ha='center', va='center',
fontsize=12, color='white', transform=ax.transAxes)
ax.set_title(
f"{subtask.name}",
fontsize=10,
fontweight='bold',
color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
pad=8
)
ax.axis('off')
# Add frame timestamp below
ax.text(
0.5, -0.08,
f"t={frame_timestamps[i]:.1f}s",
ha='center',
fontsize=9,
color='#888888',
transform=ax.transAxes
)
# Create timeline subplot spanning all columns
ax_timeline = fig.add_subplot(gs[1, :])
ax_timeline.set_facecolor('#16213e')
# Get total duration from last subtask end time
total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)
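# NOTE: the timeline is scaled to the last annotated end time, which can be
# shorter than episode_duration if the annotation does not span the full episode.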
# Draw subtask segments as colored bars
bar_height = 0.6
bar_y = 0.5
for i, subtask in enumerate(subtasks):
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]
# Draw segment bar
rect = mpatches.FancyBboxPatch(
(start_sec, bar_y - bar_height/2),
end_sec - start_sec,
bar_height,
boxstyle="round,pad=0.02,rounding_size=0.1",
facecolor=color,
edgecolor='white',
linewidth=1.5,
alpha=0.85
)
ax_timeline.add_patch(rect)
# Add subtask label inside bar
mid_x = (start_sec + end_sec) / 2
duration = end_sec - start_sec
# Only add text if segment is wide enough
if duration > total_duration * 0.08:
ax_timeline.text(
mid_x, bar_y,
subtask.name,
ha='center', va='center',
fontsize=9,
fontweight='bold',
color='black' if i % len(SUBTASK_COLORS) == 3 else 'white',  # yellow (#F0E442) needs dark text
rotation=0 if duration > total_duration * 0.15 else 45
)
# Draw boundary lines (dashed vertical lines between subtasks)
boundary_times = []
for i, subtask in enumerate(subtasks):
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
# Add start boundary (except for first subtask at t=0)
if i == 0 and start_sec > 0:
boundary_times.append(start_sec)
elif i > 0:
boundary_times.append(start_sec)
# Add end boundary for last subtask
if i == len(subtasks) - 1:
boundary_times.append(end_sec)
# Draw dashed lines at boundaries
for t in boundary_times:
ax_timeline.axvline(
x=t,
ymin=0.1, ymax=0.9,
color='white',
linestyle='--',
linewidth=2,
alpha=0.9
)
# Add time label below line
ax_timeline.text(
t, 0.0,
f"{int(t//60):02d}:{int(t%60):02d}",
ha='center', va='top',
fontsize=8,
color='#cccccc'
)
# Add start line at t=0
ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color='#00ff00', linestyle='-', linewidth=2.5, alpha=0.9)
ax_timeline.text(0, 0.0, "00:00", ha='center', va='top', fontsize=8, color='#00ff00', fontweight='bold')
# Configure timeline axes
ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
ax_timeline.set_ylim(-0.3, 1.2)
ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color='white', labelpad=10)
ax_timeline.set_ylabel("")
# Style the axes
ax_timeline.spines['top'].set_visible(False)
ax_timeline.spines['right'].set_visible(False)
ax_timeline.spines['left'].set_visible(False)
ax_timeline.spines['bottom'].set_color('#444444')
ax_timeline.tick_params(axis='x', colors='#888888', labelsize=9)
ax_timeline.tick_params(axis='y', left=False, labelleft=False)
# Add x-axis ticks at regular intervals
tick_interval = max(1, int(total_duration / 10))
ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))
# Add legend explaining line styles
legend_elements = [
Line2D([0], [0], color='#00ff00', linewidth=2.5, linestyle='-', label='Episode start'),
Line2D([0], [0], color='white', linewidth=2, linestyle='--', label='Subtask boundary'),
]
ax_timeline.legend(
handles=legend_elements,
loc='upper right',
framealpha=0.3,
facecolor='#16213e',
edgecolor='#444444',
fontsize=9,
labelcolor='white'
)
# Save figure
plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')
plt.close()
return output_path
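# Example direct call, bypassing main() (all paths and values hypothetical):
#   visualize_episode(
#       episode_idx=0,
#       annotation=annotations[0],
#       video_path=Path("videos/observation.images.top/episode_000000.mp4"),
#       video_start_timestamp=0.0,
#       video_end_timestamp=12.5,
#       output_path=Path("subtask_viz/episode_0000_subtasks.png"),
#       video_key="observation.images.top",
#   )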
def main():
parser = argparse.ArgumentParser(
description="Visualize SARM subtask annotations",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="HuggingFace dataset repository ID",
)
parser.add_argument(
"--num-episodes",
type=int,
default=5,
help="Number of random episodes to visualize (default: 5)",
)
parser.add_argument(
"--episodes",
type=int,
nargs="+",
default=None,
help="Specific episode indices to visualize (overrides --num-episodes)",
)
parser.add_argument(
"--video-key",
type=str,
default=None,
help="Camera/video key to use. If not specified, uses first available.",
)
parser.add_argument(
"--output-dir",
type=str,
default="./subtask_viz",
help="Output directory for visualizations (default: ./subtask_viz)",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility",
)
args = parser.parse_args()
console = Console()
# Set random seed if specified
if args.seed is not None:
random.seed(args.seed)
console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
dataset = LeRobotDataset(args.repo_id, download_videos=True)
fps = dataset.fps
# Get video key
if args.video_key:
if args.video_key not in dataset.meta.video_keys:
console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
console.print(f"[yellow]Available: {', '.join(dataset.meta.video_keys)}[/yellow]")
return
video_key = args.video_key
else:
video_key = dataset.meta.video_keys[0]
console.print(f"[cyan]Using camera: {video_key}[/cyan]")
console.print(f"[cyan]FPS: {fps}[/cyan]")
# Load annotations
console.print(f"\n[cyan]Loading annotations...[/cyan]")
annotations = load_annotations_from_dataset(dataset.root)
if not annotations:
console.print("[red]Error: No annotations found in dataset[/red]")
console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
return
console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")
# Determine which episodes to visualize
if args.episodes:
episode_indices = args.episodes
# Validate episodes exist
for ep in episode_indices:
if ep not in annotations:
console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
episode_indices = [ep for ep in episode_indices if ep in annotations]
else:
# Random selection
available_episodes = list(annotations.keys())
num_to_select = min(args.num_episodes, len(available_episodes))
episode_indices = random.sample(available_episodes, num_to_select)
episode_indices.sort()
if not episode_indices:
console.print("[red]Error: No valid episodes to visualize[/red]")
return
console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")
# Create output directory
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Generate visualizations
for ep_idx in episode_indices:
console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")
annotation = annotations[ep_idx]
# Get video path and timestamps
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
if not video_path.exists():
console.print(f"[red]Video not found: {video_path}[/red]")
continue
# Get episode-specific timestamps within the video file
video_from_ts_key = f"videos/{video_key}/from_timestamp"
video_to_ts_key = f"videos/{video_key}/to_timestamp"
video_start_timestamp = float(dataset.meta.episodes[video_from_ts_key][ep_idx])
video_end_timestamp = float(dataset.meta.episodes[video_to_ts_key][ep_idx])
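# NOTE: LeRobot may pack several episodes into one video file; these offsets
# locate this episode's segment within the shared file (hence the offset added
# in visualize_episode when extracting frames).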
# Create visualization
output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"
try:
visualize_episode(
episode_idx=ep_idx,
annotation=annotation,
video_path=video_path,
video_start_timestamp=video_start_timestamp,
video_end_timestamp=video_end_timestamp,
output_path=output_path,
video_key=video_key,
)
console.print(f"[green]✓ Saved: {output_path}[/green]")
except Exception as e:
console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")
# Print summary
console.print(f"\n[bold green]{'=' * 50}[/bold green]")
console.print(f"[bold green]Visualization Complete![/bold green]")
console.print(f"[bold green]{'=' * 50}[/bold green]")
console.print(f"Output directory: {output_dir.absolute()}")
console.print(f"Episodes visualized: {len(episode_indices)}")
if __name__ == "__main__":
main()