mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
add visualize subtask annotations
This commit is contained in:
@@ -0,0 +1,525 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Visualize SARM Subtask Annotations
|
||||
|
||||
This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
|
||||
For each episode, it shows:
|
||||
- A timeline with dashed vertical lines at subtask boundaries
|
||||
- Sample frames from the episode at key points (start, middle, end of each subtask)
|
||||
- Color-coded subtask segments
|
||||
|
||||
Usage:
|
||||
python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib.lines import Line2D
|
||||
from rich.console import Console
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||
|
||||
|
||||
def timestamp_to_seconds(timestamp: str) -> float:
    """Convert an "MM:SS" (or bare "SS") timestamp string to seconds."""
    fields = timestamp.split(":")
    # Anything other than two fields is treated as a plain seconds value.
    if len(fields) != 2:
        return int(fields[0])
    minutes, seconds = fields
    return int(minutes) * 60 + int(seconds)
|
||||
|
||||
|
||||
def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
    """
    Read subtask annotations back out of a LeRobot dataset's episode metadata.

    Returns a mapping of episode index -> SubtaskAnnotation. The mapping is
    empty when the dataset has no episodes or the subtask columns are absent
    (i.e. subtask_annotation.py was never run on this dataset).
    """
    episodes_dataset = load_episodes(dataset_path)

    if episodes_dataset is None or len(episodes_dataset) == 0:
        return {}

    # Subtask columns are optional metadata written by subtask_annotation.py.
    if "subtask_names" not in episodes_dataset.column_names:
        return {}

    episodes_df = episodes_dataset.to_pandas()

    def _mmss(seconds: float) -> str:
        # Stored times are in seconds; the annotation objects use "MM:SS" strings.
        total = int(seconds)
        return f"{total // 60:02d}:{total % 60:02d}"

    annotations: dict[int, SubtaskAnnotation] = {}

    for ep_idx in episodes_df.index:
        names = episodes_df.loc[ep_idx, "subtask_names"]

        # Episodes without annotations carry None/NaN in this column.
        if names is None or (isinstance(names, float) and pd.isna(names)):
            continue

        starts = episodes_df.loc[ep_idx, "subtask_start_times"]
        ends = episodes_df.loc[ep_idx, "subtask_end_times"]

        # Rebuild the SubtaskAnnotation structure from the flat parquet columns.
        subtasks = [
            Subtask(name=name, timestamps=Timestamp(start=_mmss(start), end=_mmss(end)))
            for name, start, end in zip(names, starts, ends)
        ]

        annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)

    return annotations
|
||||
|
||||
|
||||
# Color palette for subtasks (colorblind-friendly).
# Always indexed modulo its length, so episodes with more than 8 subtasks
# cycle back through the same colors.
SUBTASK_COLORS = [
    "#E69F00", # Orange
    "#56B4E9", # Sky blue
    "#009E73", # Bluish green
    "#F0E442", # Yellow
    "#0072B2", # Blue
    "#D55E00", # Vermillion
    "#CC79A7", # Reddish purple
    "#999999", # Gray
]
|
||||
|
||||
|
||||
def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
    """Extract a single RGB frame from ``video_path`` at ``timestamp`` (seconds).

    Returns None when the video cannot be opened or the seek/read fails.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        # Release even on open failure: VideoCapture may still hold resources.
        cap.release()
        return None

    try:
        # Seek by wall-clock position; OpenCV expects milliseconds.
        cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
        ret, frame = cap.read()
    finally:
        # Always free the underlying decoder, even if set()/read() raises.
        cap.release()

    if not ret:
        return None
    # OpenCV decodes to BGR; matplotlib (the consumer) expects RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
def visualize_episode(
    episode_idx: int,
    annotation,
    video_path: Path,
    video_start_timestamp: float,
    video_end_timestamp: float,
    fps: int,
    output_path: Path,
    video_key: str,
):
    """
    Create and save the visualization for a single episode.

    Layout:
    - Top row: one sample frame per subtask (taken from the subtask midpoint)
    - Bottom: a timeline with color-coded subtask segments and dashed
      boundary lines at subtask transitions

    Args:
        episode_idx: Episode index, used in the figure title and log output.
        annotation: Object exposing a ``subtasks`` list, each with ``name``
            and ``timestamps.start`` / ``timestamps.end`` ("MM:SS" strings).
        video_path: Path to the video file containing this episode.
        video_start_timestamp: Offset (s) of the episode start within the video file.
        video_end_timestamp: Offset (s) of the episode end within the video file.
        fps: Dataset frame rate (currently unused; kept for interface compatibility).
        output_path: Destination PNG path.
        video_key: Camera key, shown in the figure subtitle.

    Returns:
        ``output_path`` on success, or None when the annotation has no subtasks.
    """
    subtasks = annotation.subtasks
    num_subtasks = len(subtasks)

    if num_subtasks == 0:
        print(f"No subtasks found for episode {episode_idx}")
        return

    # Calculate episode duration
    episode_duration = video_end_timestamp - video_start_timestamp

    # Extract sample frames - get frame from middle of each subtask
    sample_frames = []
    frame_timestamps = []

    for subtask in subtasks:
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)
        mid_sec = (start_sec + end_sec) / 2

        # Annotation times are episode-relative; the video file may contain
        # several episodes, so shift by the episode's start offset.
        video_timestamp = video_start_timestamp + mid_sec
        frame_timestamps.append(mid_sec)

        frame = extract_frame_from_video(video_path, video_timestamp)
        sample_frames.append(frame)

    # Create figure
    fig = plt.figure(figsize=(16, 10))

    # Use a dark background for better contrast
    fig.patch.set_facecolor('#1a1a2e')

    # Grid layout: top row of frames (one column per subtask), bottom timeline
    # spanning all columns.
    gs = fig.add_gridspec(
        2, max(num_subtasks, 1),
        height_ratios=[2, 1],
        hspace=0.3,
        wspace=0.1,
        left=0.05, right=0.95,
        top=0.88, bottom=0.1
    )

    # Add title
    fig.suptitle(
        f"Episode {episode_idx} - Subtask Annotations",
        fontsize=18,
        fontweight='bold',
        color='white',
        y=0.96
    )

    # Add subtitle with video info
    fig.text(
        0.5, 0.91,
        f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
        ha='center',
        fontsize=11,
        color='#888888'
    )

    # Plot sample frames
    for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
        ax = fig.add_subplot(gs[0, i])
        ax.set_facecolor('#16213e')

        if frame is not None:
            ax.imshow(frame)
        else:
            # Frame extraction can fail (missing/corrupt video); show a placeholder.
            ax.text(0.5, 0.5, "Frame\nN/A", ha='center', va='center',
                    fontsize=12, color='white', transform=ax.transAxes)

        ax.set_title(
            f"{subtask.name}",
            fontsize=10,
            fontweight='bold',
            color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
            pad=8
        )
        ax.axis('off')

        # Add frame timestamp below
        ax.text(
            0.5, -0.08,
            f"t={frame_timestamps[i]:.1f}s",
            ha='center',
            fontsize=9,
            color='#888888',
            transform=ax.transAxes
        )

    # Create timeline subplot spanning all columns
    ax_timeline = fig.add_subplot(gs[1, :])
    ax_timeline.set_facecolor('#16213e')

    # Get total duration from last subtask end time
    total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)

    # Draw subtask segments as colored bars
    bar_height = 0.6
    bar_y = 0.5

    for i, subtask in enumerate(subtasks):
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)
        color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]

        # Draw segment bar
        rect = mpatches.FancyBboxPatch(
            (start_sec, bar_y - bar_height/2),
            end_sec - start_sec,
            bar_height,
            boxstyle="round,pad=0.02,rounding_size=0.1",
            facecolor=color,
            edgecolor='white',
            linewidth=1.5,
            alpha=0.85
        )
        ax_timeline.add_patch(rect)

        # Add subtask label inside bar
        mid_x = (start_sec + end_sec) / 2
        duration = end_sec - start_sec

        # Only add text if segment is wide enough
        if duration > total_duration * 0.08:
            ax_timeline.text(
                mid_x, bar_y,
                subtask.name,
                ha='center', va='center',
                fontsize=9,
                fontweight='bold',
                # The yellow palette entry (index 3) needs dark text. Colors
                # cycle modulo the palette length, so the contrast check must
                # use the same cycled index (was: `i in [3]`, which broke for
                # subtask 11, 19, ...).
                color='black' if i % len(SUBTASK_COLORS) == 3 else 'white',
                rotation=0 if duration > total_duration * 0.15 else 45
            )

    # Collect boundary times (dashed vertical lines between subtasks)
    boundary_times = []
    for i, subtask in enumerate(subtasks):
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)

        # Mark every subtask start, except the first one when it begins at
        # t=0 (the solid green start line already marks that position).
        if i > 0 or start_sec > 0:
            boundary_times.append(start_sec)

        # Add end boundary for last subtask
        if i == len(subtasks) - 1:
            boundary_times.append(end_sec)

    # Draw dashed lines at boundaries
    for t in boundary_times:
        ax_timeline.axvline(
            x=t,
            ymin=0.1, ymax=0.9,
            color='white',
            linestyle='--',
            linewidth=2,
            alpha=0.9
        )

        # Add time label below line
        ax_timeline.text(
            t, 0.0,
            f"{int(t//60):02d}:{int(t%60):02d}",
            ha='center', va='top',
            fontsize=8,
            color='#cccccc'
        )

    # Add start line at t=0
    ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color='#00ff00', linestyle='-', linewidth=2.5, alpha=0.9)
    ax_timeline.text(0, 0.0, "00:00", ha='center', va='top', fontsize=8, color='#00ff00', fontweight='bold')

    # Configure timeline axes
    ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
    ax_timeline.set_ylim(-0.3, 1.2)
    ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color='white', labelpad=10)
    ax_timeline.set_ylabel("")

    # Style the axes
    ax_timeline.spines['top'].set_visible(False)
    ax_timeline.spines['right'].set_visible(False)
    ax_timeline.spines['left'].set_visible(False)
    ax_timeline.spines['bottom'].set_color('#444444')
    ax_timeline.tick_params(axis='x', colors='#888888', labelsize=9)
    ax_timeline.tick_params(axis='y', left=False, labelleft=False)

    # Add x-axis ticks at regular intervals
    tick_interval = max(1, int(total_duration / 10))
    ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))

    # Add legend explaining line styles
    legend_elements = [
        Line2D([0], [0], color='#00ff00', linewidth=2.5, linestyle='-', label='Start'),
        Line2D([0], [0], color='white', linewidth=2, linestyle='--', label='Subtask boundary'),
    ]
    ax_timeline.legend(
        handles=legend_elements,
        loc='upper right',
        framealpha=0.3,
        facecolor='#16213e',
        edgecolor='#444444',
        fontsize=9,
        labelcolor='white'
    )

    # Save figure
    plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')
    # Close this specific figure (plt.close() with no argument closes whatever
    # the "current" figure is, which may not be ours if the caller has other
    # figures open) to avoid leaking figure memory across episodes.
    plt.close(fig)

    return output_path
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: load a LeRobot dataset, read its subtask annotations,
    and render one PNG visualization per selected episode.

    Episodes are chosen either explicitly (--episodes) or at random
    (--num-episodes, optionally seeded with --seed). Exits early (plain
    return) on bad --video-key or when no annotations are found.
    """
    parser = argparse.ArgumentParser(
        description="Visualize SARM subtask annotations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID",
    )
    parser.add_argument(
        "--num-episodes",
        type=int,
        default=5,
        help="Number of random episodes to visualize (default: 5)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        default=None,
        help="Specific episode indices to visualize (overrides --num-episodes)",
    )
    parser.add_argument(
        "--video-key",
        type=str,
        default=None,
        help="Camera/video key to use. If not specified, uses first available.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./subtask_viz",
        help="Output directory for visualizations (default: ./subtask_viz)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )

    args = parser.parse_args()

    console = Console()

    # Set random seed if specified (makes the random episode sample reproducible)
    if args.seed is not None:
        random.seed(args.seed)

    console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
    # download_videos=True: frames are extracted from the raw video files below
    dataset = LeRobotDataset(args.repo_id, download_videos=True)
    fps = dataset.fps

    # Get video key (validate the user's choice, or default to the first camera)
    if args.video_key:
        if args.video_key not in dataset.meta.video_keys:
            console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
            console.print(f"[yellow]Available: {', '.join(dataset.meta.video_keys)}[/yellow]")
            return
        video_key = args.video_key
    else:
        video_key = dataset.meta.video_keys[0]

    console.print(f"[cyan]Using camera: {video_key}[/cyan]")
    console.print(f"[cyan]FPS: {fps}[/cyan]")

    # Load annotations from the dataset's episode metadata parquet files
    console.print(f"\n[cyan]Loading annotations...[/cyan]")
    annotations = load_annotations_from_dataset(dataset.root)

    if not annotations:
        console.print("[red]Error: No annotations found in dataset[/red]")
        console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
        return

    console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")

    # Determine which episodes to visualize
    if args.episodes:
        episode_indices = args.episodes
        # Validate episodes exist (warn, then drop the unannotated ones)
        for ep in episode_indices:
            if ep not in annotations:
                console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
        episode_indices = [ep for ep in episode_indices if ep in annotations]
    else:
        # Random selection (sorted afterwards for stable, readable output order)
        available_episodes = list(annotations.keys())
        num_to_select = min(args.num_episodes, len(available_episodes))
        episode_indices = random.sample(available_episodes, num_to_select)
        episode_indices.sort()

    if not episode_indices:
        console.print("[red]Error: No valid episodes to visualize[/red]")
        return

    console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate visualizations
    for ep_idx in episode_indices:
        console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")

        annotation = annotations[ep_idx]

        # Get video path and timestamps
        video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)

        if not video_path.exists():
            console.print(f"[red]Video not found: {video_path}[/red]")
            continue

        # Get episode-specific timestamps within the video file (one file can
        # hold several episodes; these columns give this episode's slice)
        video_path_key = f"videos/{video_key}/from_timestamp"
        video_path_key_to = f"videos/{video_key}/to_timestamp"

        video_start_timestamp = float(dataset.meta.episodes[video_path_key][ep_idx])
        video_end_timestamp = float(dataset.meta.episodes[video_path_key_to][ep_idx])

        # Create visualization
        output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"

        # One failing episode should not abort the whole batch
        try:
            visualize_episode(
                episode_idx=ep_idx,
                annotation=annotation,
                video_path=video_path,
                video_start_timestamp=video_start_timestamp,
                video_end_timestamp=video_end_timestamp,
                fps=fps,
                output_path=output_path,
                video_key=video_key,
            )
            console.print(f"[green]✓ Saved: {output_path}[/green]")
        except Exception as e:
            console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")

    # Print summary
    console.print(f"\n[bold green]{'=' * 50}[/bold green]")
    console.print(f"[bold green]Visualization Complete![/bold green]")
    console.print(f"[bold green]{'=' * 50}[/bold green]")
    console.print(f"Output directory: {output_dir.absolute()}")
    console.print(f"Episodes visualized: {len(episode_indices)}")
|
||||
|
||||
|
||||
# Script entry point
if __name__ == "__main__":
    main()
|
||||
|
||||
Reference in New Issue
Block a user