From b98c70376b239cc5213053a599b21f9ef1e992a6 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Fri, 28 Nov 2025 12:16:16 +0100
Subject: [PATCH] Fix visualization and change prompt

---
 .../dataset_annotation/subtask_annotation.py |  96 ++++----
 scripts/visualize_sarm_predictions.py        | 212 ++++++++++++++++--
 2 files changed, 233 insertions(+), 75 deletions(-)

diff --git a/examples/dataset_annotation/subtask_annotation.py b/examples/dataset_annotation/subtask_annotation.py
index 6c5148923..97dbf10f6 100644
--- a/examples/dataset_annotation/subtask_annotation.py
+++ b/examples/dataset_annotation/subtask_annotation.py
@@ -96,40 +96,49 @@ def create_sarm_prompt(subtask_list: list[str]) -> str:
     """
     subtask_str = "\n".join([f"  - {name}" for name in subtask_list])
 
-    return f"""You are an expert video annotator. Analyze this robot manipulation video and identify when each subtask occurs.
+    return f"""# Role
+You are an expert Robotics Vision System specializing in temporal action localization. Your task is to segment a video of a robot manipulation demonstration into a sequence of distinct, non-overlapping atomic actions.
 
-WATCH THE ENTIRE VIDEO FIRST:
-
-CRITICAL REQUIREMENTS:
-1. You MUST use ONLY these EXACT subtask names (no variations, no other names):
+# Input Data
+## Allowed Subtask Vocabulary
+You must strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones:
+[
 {subtask_str}
-2. Identify the start and end timestamp for each subtask that occurs in the video
-3. Subtasks should be in chronological order
-4. Timestamps should be in MM:SS format (e.g., "00:15" for 15 seconds, "01:30" for 1 minute 30 seconds)
-5. Subtasks should cover the entire demonstration without gaps
-6. You MUST watch the COMPLETE video from start to finish before making ANY annotations or conclusions
-7. Do NOT start annotating until you have seen the entire video length
-8. Only after viewing the complete video should you identify the timestamps
-9. EACH SUBTASK HAPPENS ONLY ONCE in the video - do not identify the same subtask multiple times
-10. Note the exact times when each subtask starts and ends, but make sure to cover the ENTIRE video timeline.
+]
 
-FORMAT:
-Return a JSON list of subtasks with their timestamps. Each subtask must have:
-- "name": One of the exact names from the list above
-- "timestamps": object with "start" and "end" fields (MM:SS format)
+# Constraints & Logic
+1. **Continuous Coverage:** The entire video duration (from 00:00 to the final second) must be accounted for. There can be no gaps between tasks.
+2. **Boundary Logic:** The `end` timestamp of one task must be the exact `start` timestamp of the next task.
+3. **Linear Progression:** The video represents a single successful demonstration. Each subtask from the vocabulary appears exactly once, in logical chronological order.
+4. **Format:** Timestamps must be in "MM:SS" format.
 
-Example structure:
-{{
+# Step-by-Step Analysis Process
+1. **Visual Grounding:** Look for the specific visual state changes that define the transition between tasks (e.g., gripper touching object, object lifting off table).
+2. **Define Boundaries:** Determine the specific frame where the motion profile changes to fit the next subtask label.
+3. **Fill Gaps:** If there is a pause between meaningful actions, append that time to the *preceding* task to ensure continuous coverage.
+
+# Output Format
+Provide the output in valid JSON format.
+Structure:
+{{
   "subtasks": [
-    {{"name": "reach_to_object", "timestamps": {{"start": "00:00", "end": "00:05"}}}},
-    {{"name": "grasp_object", "timestamps": {{"start": "00:05", "end": "00:08"}}}},
-    ...
+    {{
+      "name": "EXACT_NAME_FROM_LIST",
+      "timestamps": {{
+        "start": "MM:SS",
+        "end": "MM:SS"
+      }}
+    }},
+    {{
+      "name": "EXACT_NAME_FROM_LIST",
+      "timestamps": {{
+        "start": "MM:SS",
+        "end": "MM:SS"
+      }}
+    }}
   ]
-}}
-
-Remember: Use ONLY the subtask names provided above, and cover the ENTIRE video timeline."""
-
+}}
+"""
 
 
 class VideoAnnotator:
     """Annotates robot manipulation videos using local Qwen3-VL model on GPU"""

@@ -328,9 +337,9 @@ class VideoAnnotator:
         # Add video duration to prompt
         prompt_with_duration = f"""{self.prompt}
 
-CRITICAL - VIDEO DURATION:
-The video is {duration_str} long ({duration_seconds:.1f} seconds). Your annotations MUST cover the ENTIRE duration from 00:00 to {duration_str}.
-Do NOT stop annotating before the video ends. Make sure your last subtask ends at {duration_str} or very close to it."""
+# Video Duration
+The video is {duration_str} long ({duration_seconds:.1f} seconds). Your total annotations MUST cover the ENTIRE duration from 00:00 to {duration_str}.
+Do NOT stop annotating before the video ends. Make sure your last subtask ends at {duration_str}."""
 
         # Prepare messages for the model
         messages = [

@@ -771,27 +780,11 @@ Examples:
         --video-key observation.images.top \\
         --num-workers 4 \\
         --push-to-hub
-
-    # Parallel with specific GPU IDs (e.g., GPUs 0, 2, 3):
-    python subtask_annotation.py \\
-        --repo-id pepijn223/mydataset \\
-        --subtasks "reach,grasp,lift,place" \\
-        --video-key observation.images.top \\
-        --num-workers 3 \\
-        --gpu-ids 0 2 3 \\
-        --push-to-hub
-
-    # List available cameras:
-    python subtask_annotation.py \\
-        --repo-id pepijn223/mydataset \\
-        --subtasks "reach,grasp" \\
-        --max-episodes 0
 
-Performance Tips:
+Performance notes:
     - Each worker loads one model instance on its assigned GPU
    - The 30B model requires ~60GB VRAM per GPU
-    - Use --num-workers N for N GPUs to get N× speedup
-    - Episodes are distributed round-robin across workers
+    - Use --num-workers N for N GPUs
 """
     )
     parser.add_argument(

@@ -885,10 +878,7 @@
 
     args = parser.parse_args()
 
-    # Parse subtask list
     subtask_list = [s.strip() for s in args.subtasks.split(",")]
-
-    # Parse dtype
     dtype_map = {
         "bfloat16": torch.bfloat16,
         "float16": torch.float16,
@@ -906,11 +896,9 @@
         border_style="cyan"
     ))
 
-    # Load dataset
     console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
     dataset = LeRobotDataset(args.repo_id, download_videos=True)
 
-    # Get FPS from dataset
     fps = dataset.fps
     console.print(f"[cyan]Dataset FPS: {fps}[/cyan]")
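
Note (not part of the patch): the rewritten prompt pins the model to a strict contract: exact vocabulary, "MM:SS" timestamps, and gapless boundaries where each segment starts where the previous one ends. A minimal caller-side validator for that contract could look like the sketch below; `parse_annotation_reply` is a hypothetical helper and the strict two-digit "MM:SS" regex is an assumption, neither is added by this patch:

    import json
    import re

    def parse_annotation_reply(reply: str, subtask_list: list[str]) -> list[dict]:
        """Validate a model reply against the prompt contract (hypothetical helper)."""
        subtasks = json.loads(reply)["subtasks"]
        for item in subtasks:
            # Vocabulary check: names must come from the allowed list, verbatim.
            if item["name"] not in subtask_list:
                raise ValueError(f"unknown subtask name: {item['name']!r}")
            # Timestamp check: the prompt demands "MM:SS" (assumed two digits each).
            for key in ("start", "end"):
                if not re.fullmatch(r"\d{2}:\d{2}", item["timestamps"][key]):
                    raise ValueError(f"bad {key} timestamp: {item['timestamps'][key]!r}")
        # Boundary check: each segment must start where the previous one ended.
        for prev, cur in zip(subtasks, subtasks[1:]):
            if prev["timestamps"]["end"] != cur["timestamps"]["start"]:
                raise ValueError(f"gap between {prev['name']} and {cur['name']}")
        return subtasks
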
diff --git a/scripts/visualize_sarm_predictions.py b/scripts/visualize_sarm_predictions.py
index a929cf233..352438674 100644
--- a/scripts/visualize_sarm_predictions.py
+++ b/scripts/visualize_sarm_predictions.py
@@ -30,6 +30,7 @@ Example usage:
 """
 
 import argparse
+import json
 import logging
 from pathlib import Path
 from typing import Optional
@@ -38,12 +39,17 @@
 import matplotlib.pyplot as plt
 import matplotlib.gridspec as gridspec
 import matplotlib.patches as mpatches
 import numpy as np
+import pandas as pd
 import torch
 from tqdm import tqdm
 
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
-from lerobot.policies.sarm.sarm_utils import pad_state_to_max_dim
+from lerobot.policies.sarm.sarm_utils import (
+    pad_state_to_max_dim,
+    compute_tau,
+    compute_cumulative_progress_batch,
+)
 from lerobot.datasets.utils import load_stats

@@ -328,6 +334,110 @@ def run_inference(
     return np.array(all_progress), np.array(all_stages)
 
 
+def compute_ground_truth_progress(
+    dataset: LeRobotDataset,
+    episode_index: int,
+    temporal_proportions: dict[str, float],
+    subtask_names_ordered: list[str],
+) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
+    """
+    Compute ground truth progress and stage labels for an episode using annotations.
+
+    Uses SARM paper formula (2):
+        y_t = P_{k-1} + ᾱ_k × τ_t
+
+    where:
+    - τ_t = (t - s_k) / (e_k - s_k) is the within-subtask progress
+    - P_{k-1} is the cumulative prior (sum of the previous subtask proportions)
+    - ᾱ_k is the temporal proportion of subtask k
+
+    Args:
+        dataset: LeRobotDataset instance
+        episode_index: Index of the episode
+        temporal_proportions: Dict mapping subtask name to proportion
+        subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)
+
+    Returns:
+        Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None) if no annotations
+    """
+    # Load episode metadata
+    episodes_df = dataset.meta.episodes.to_pandas()
+
+    # Check if annotations exist
+    if "subtask_names" not in episodes_df.columns:
+        logger.warning("No subtask_names column found in episodes metadata")
+        return None, None
+
+    ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
+    if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
+        logger.warning(f"No annotations found for episode {episode_index}")
+        return None, None
+
+    subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
+    subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]
+
+    # Get episode boundaries
+    ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
+    ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
+    num_frames = ep_end - ep_start
+
+    # Get temporal proportions as an ordered list
+    temporal_proportions_list = [
+        temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
+    ]
+
+    logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
+    logger.info(f"Subtask names in episode: {ep_subtask_names}")
+    logger.info(f"Subtask start frames: {subtask_start_frames}")
+    logger.info(f"Subtask end frames: {subtask_end_frames}")
+    logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")
+
+    # Compute ground truth for each frame
+    gt_progress = np.zeros(num_frames)
+    gt_stages = np.zeros(num_frames, dtype=np.int32)
+
+    for frame_rel in range(num_frames):
+        # Find which subtask this frame belongs to
+        found = False
+        for name, start_frame, end_frame in zip(ep_subtask_names, subtask_start_frames, subtask_end_frames):
+            if start_frame <= frame_rel <= end_frame:
+                # Found the subtask - get its global stage index
+                stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else 0
+
+                # Compute τ_t using the utility function
+                tau = compute_tau(frame_rel, start_frame, end_frame)
+
+                # Compute cumulative progress using the utility function
+                progress = compute_cumulative_progress_batch(tau, stage_idx, temporal_proportions_list)
+
+                gt_progress[frame_rel] = progress
+                gt_stages[frame_rel] = stage_idx
+                found = True
+                break
+
+        if not found:
+            # Handle frames outside the annotated subtasks
+            if frame_rel < subtask_start_frames[0]:
+                gt_progress[frame_rel] = 0.0
+                gt_stages[frame_rel] = 0
+            elif frame_rel > subtask_end_frames[-1]:
+                gt_progress[frame_rel] = 1.0
+                gt_stages[frame_rel] = len(subtask_names_ordered) - 1
+            else:
+                # Between subtasks - attribute the frame to the preceding subtask
+                for j in range(len(ep_subtask_names) - 1):
+                    if subtask_end_frames[j] < frame_rel < subtask_start_frames[j + 1]:
+                        name = ep_subtask_names[j]
+                        stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else j
+                        progress = compute_cumulative_progress_batch(1.0, stage_idx, temporal_proportions_list)
+                        gt_progress[frame_rel] = progress
+                        gt_stages[frame_rel] = stage_idx
+                        break
+
+    logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
+    return gt_progress, gt_stages
+
+
 def visualize_predictions(
     frames: np.ndarray,
     progress_predictions: np.ndarray,
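
Note (not part of the patch): for intuition, here is formula (2) from the docstring above worked on a toy episode. The two helpers are hand-rolled stand-ins for `compute_tau` and `compute_cumulative_progress_batch`, whose internals are not shown in this patch; they implement exactly the τ_t and y_t definitions the docstring states:

    # Toy episode: 10 frames, two subtasks with proportions ᾱ = [0.25, 0.75].
    # "grasp" spans frames 0-4, "lift" spans frames 5-9.
    props = [0.25, 0.75]

    def tau(t, s, e):
        # Within-subtask progress: τ_t = (t - s_k) / (e_k - s_k)
        return (t - s) / (e - s)

    def progress(t, k, s, e):
        # Formula (2): y_t = P_{k-1} + ᾱ_k * τ_t
        return sum(props[:k]) + props[k] * tau(t, s, e)

    assert progress(2, 0, 0, 4) == 0.125  # halfway through "grasp": 0 + 0.25 * 0.5
    assert progress(7, 1, 5, 9) == 0.625  # halfway through "lift": 0.25 + 0.75 * 0.5
    assert progress(9, 1, 5, 9) == 1.0    # last frame: cumulative progress reaches 1
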
@@ -337,10 +447,12 @@
     num_sample_frames: int = 8,
     figsize: tuple = (14, 8),
     subtask_names: list[str] | None = None,
-    temporal_proportions: dict[str, float] | None = None
+    temporal_proportions: dict[str, float] | None = None,
+    ground_truth_progress: np.ndarray | None = None,
+    ground_truth_stages: np.ndarray | None = None,
 ):
     """
-    Create visualization of SARM predictions.
+    Create visualization of SARM predictions with optional ground truth comparison.
 
     Args:
         frames: Video frames (num_frames, H, W, C)
@@ -348,11 +460,12 @@
         stage_predictions: Stage probabilities (num_frames, num_stages)
         task_description: Task description
         output_path: Path to save the figure
-        show_frames: Whether to include sample frames
         num_sample_frames: Number of frames to show
         figsize: Figure size (width, height)
         subtask_names: Optional list of subtask names for labeling
         temporal_proportions: Optional dict of temporal proportions for each subtask
+        ground_truth_progress: Optional ground truth progress array (num_frames,)
+        ground_truth_stages: Optional ground truth stage indices array (num_frames,)
     """
     num_stages = stage_predictions.shape[1]
     stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
@@ -376,10 +489,16 @@
     # Plot 1: Progress over time
     ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
     ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')
+
+    # Plot ground truth if available
+    if ground_truth_progress is not None:
+        ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
+                         linestyle='--', label='Ground Truth Progress')
+        ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')
+
     ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
     ax_progress.set_ylabel('Task Progress', fontsize=12)
-    ax_progress.set_title(f'SARM Task Progress & Stage Prediction\nTask: "{task_description}"',
-                          fontsize=14, fontweight='bold')
+    ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
     ax_progress.grid(True, alpha=0.3)
     ax_progress.set_ylim(-0.05, 1.1)
     ax_progress.legend(loc='upper left')
@@ -391,6 +510,11 @@
         f'Max Progress: {progress_predictions.max():.3f}\n'
         f'Mean Progress: {progress_predictions.mean():.3f}'
     )
+    if ground_truth_progress is not None:
+        mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
+        stats_text += f'\nMSE vs GT: {mse:.4f}'
+        stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'
+
     ax_progress.text(0.98, 0.02, stats_text,
                      transform=ax_progress.transAxes, fontsize=10,
                      verticalalignment='bottom', horizontalalignment='right',
                      bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
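
Note (not part of the patch): the next hunk locates ground-truth stage transitions with `np.where(np.diff(...) != 0)[0] + 1`. On a toy stage trace this yields the first frame index of each new stage:

    import numpy as np

    gt = np.array([0, 0, 0, 1, 1, 2, 2])         # ground-truth stage per frame
    changes = np.where(np.diff(gt) != 0)[0] + 1  # np.diff is nonzero at indices [2, 4]
    print(changes)                               # [3 5]: frames where a new stage begins
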
@@ -398,6 +522,20 @@
     # Plot 2: Stage predictions (stacked area plot)
     ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
                         colors=stage_colors, alpha=0.8, labels=stage_labels)
+
+    # Plot ground truth stage transitions as vertical lines
+    if ground_truth_stages is not None:
+        # Find stage transition points in the ground truth
+        stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
+        for change_idx in stage_changes:
+            ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
+            ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)
+
+        # Add small markers along the bottom showing the GT stage (subsampled every 30 frames)
+        ax_stages.scatter(frame_indices[::30], np.zeros(len(frame_indices[::30])) + 0.02,
+                          c=[stage_colors[s] for s in ground_truth_stages[::30]],
+                          s=20, marker='|', alpha=0.8, label='GT Stage Markers')
+
     ax_stages.set_xlabel('Frame Index', fontsize=12)
     ax_stages.set_ylabel('Stage Probability', fontsize=12)
     ax_stages.set_ylim(0, 1)

@@ -540,20 +678,42 @@ def main():
         subtask_names = model.config.subtask_names
         logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
 
-    # Try to load temporal proportions from model's dataset meta
-    if hasattr(model, 'dataset_stats') and model.dataset_stats is not None:
-        if 'temporal_proportions' in model.dataset_stats:
-            temporal_proportions = model.dataset_stats['temporal_proportions']
-            logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
+    # Try to load temporal proportions from the model config
+    if hasattr(model.config, 'temporal_proportions') and model.config.temporal_proportions is not None:
+        temporal_proportions = {
+            name: prop for name, prop in zip(model.config.subtask_names, model.config.temporal_proportions)
+        }
+        logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")
 
-    # # Fallback: try to load from dataset meta
-    # if temporal_proportions is None and subtask_names is not None:
-    #     import json
-    #     proportions_path = dataset.root / "meta" / "temporal_proportions.json"
-    #     if proportions_path.exists():
-    #         with open(proportions_path, 'r') as f:
-    #             temporal_proportions = json.load(f)
-    #         logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
+    # Fallback: try to load from dataset meta
+    if temporal_proportions is None:
+        proportions_path = dataset.root / "meta" / "temporal_proportions.json"
+        if proportions_path.exists():
+            with open(proportions_path, 'r') as f:
+                temporal_proportions = json.load(f)
+            logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
+
+            # Also extract subtask names from the proportions if not already set
+            if subtask_names is None:
+                subtask_names = sorted(temporal_proportions.keys())
+                logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")
+
+    # Compute ground truth progress if annotations are available
+    ground_truth_progress = None
+    ground_truth_stages = None
+
+    if temporal_proportions is not None and subtask_names is not None:
+        logger.info("Attempting to compute ground truth progress from annotations...")
+        ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
+            dataset,
+            args.episode_index,
+            temporal_proportions,
+            subtask_names,
+        )
+        if ground_truth_progress is None:
+            logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
+    else:
+        logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")
 
     output_dir = Path(args.output_dir)
     output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
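
Note (not part of the patch): the fallback branch above implies that `meta/temporal_proportions.json` holds a flat name-to-proportion mapping, since the loaded object is fed to `temporal_proportions.get(name, 0.0)` and `sorted(temporal_proportions.keys())`. An illustrative file; the names and values here are made up, and the proportions should sum to 1.0:

    {
      "reach_to_object": 0.3,
      "grasp_object": 0.2,
      "lift_object": 0.2,
      "place_object": 0.3
    }
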
@@ -567,11 +727,20 @@
         num_sample_frames=args.num_sample_frames,
         figsize=tuple(args.figsize),
         subtask_names=subtask_names,
-        temporal_proportions=temporal_proportions
+        temporal_proportions=temporal_proportions,
+        ground_truth_progress=ground_truth_progress,
+        ground_truth_stages=ground_truth_stages,
     )
 
     predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
-    np.savez(predictions_path, progress=progress_predictions, stages=stage_predictions)
+    save_dict = {
+        'progress': progress_predictions,
+        'stages': stage_predictions,
+    }
+    if ground_truth_progress is not None:
+        save_dict['gt_progress'] = ground_truth_progress
+        save_dict['gt_stages'] = ground_truth_stages
+    np.savez(predictions_path, **save_dict)
     logger.info(f"Saved predictions to {predictions_path}")
 
     logger.info(f"\nVisualization: {output_path}")
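
Note (not part of the patch): the arrays saved above can be reloaded for offline analysis. The key names match `save_dict` in this patch; the output path and episode number below are illustrative:

    import numpy as np

    data = np.load("outputs/predictions_ep0.npz")
    progress, stages = data["progress"], data["stages"]
    if "gt_progress" in data.files:  # present only when annotations were available
        gt_progress, gt_stages = data["gt_progress"], data["gt_stages"]
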