Fix visualization and change prompt

Pepijn
2025-11-28 12:16:16 +01:00
parent 2fa045eedc
commit b98c70376b
2 changed files with 233 additions and 75 deletions
@@ -96,40 +96,49 @@ def create_sarm_prompt(subtask_list: list[str]) -> str:
"""
subtask_str = "\n".join([f" - {name}" for name in subtask_list])
return f"""You are an expert video annotator. Analyze this robot manipulation video and identify when each subtask occurs.
return f"""# Role
You are an expert Robotics Vision System specializing in temporal action localization. Your task is to segment a video of a robot manipulation demonstration into a sequence of distinct, non-overlapping atomic actions.
WATCH THE ENTIRE VIDEO FIRST:
CRITICAL REQUIREMENTS:
1. You MUST use ONLY these EXACT subtask names (no variations, no other names):
# Input Data
## Allowed Subtask Vocabulary
You must strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones:
[
{subtask_str}
2. Identify the start and end timestamp for each subtask that occurs in the video
3. Subtasks should be in chronological order
4. Timestamps should be in MM:SS format (e.g., "00:15" for 15 seconds, "01:30" for 1 minute 30 seconds)
5. Subtasks should cover the entire demonstration without gaps
6. You MUST watch the COMPLETE video from start to finish before making ANY annotations or conclusions
7. Do NOT start annotating until you have seen the entire video length
8. Only after viewing the complete video should you identify the timestamps
9. EACH SUBTASK HAPPENS ONLY ONCE in the video - do not identify the same subtask multiple times
10. Note the exact times when each subtask starts and ends, but make sure to cover the ENTIRE video timeline.
]
FORMAT:
Return a JSON list of subtasks with their timestamps. Each subtask must have:
- "name": One of the exact names from the list above
- "timestamps": object with "start" and "end" fields (MM:SS format)
# Constraints & Logic
1. **Continuous Coverage:** The entire video duration (from 00:00 to the final second) must be accounted for. There can be no gaps between tasks.
2. **Boundary Logic:** The `end` timestamp of one task must be the exact `start` timestamp of the next task.
3. **Linear Progression:** The video represents a single successful demonstration. Each subtask from the vocabulary appears exactly once, in logical chronological order.
4. **Format:** Timestamps must be in "MM:SS" format.
Example structure:
{{
# Step-by-Step Analysis Process
1. **Visual grounding:** Look for the specific visual state changes that define the transition between tasks (e.g., gripper touching object, object lifting off table).
2. **Define Boundaries:** Determine the specific frame where the motion profile changes to fit the next subtask label.
3. **Fill Gaps:** If there is a pause between meaningful actions, append that time to the *preceding* task to ensure continuous coverage.
# Output Format
Provide the output in valid JSON format.
Structure:
{
"subtasks": [
{{"name": "reach_to_object", "timestamps": {{"start": "00:00", "end": "00:05"}}}},
{{"name": "grasp_object", "timestamps": {{"start": "00:05", "end": "00:08"}}}},
...
{
"name": "EXACT_NAME_FROM_LIST",
"timestamps": {
"start": "MM:SS",
"end": "MM:SS"
}
},
{
"name": "EXACT_NAME_FROM_LIST",
"timestamps": {
"start": "MM:SS",
"end": "MM:SS"
}
}
]
}}
Remember: Use ONLY the subtask names provided above, and cover the ENTIRE video timeline."""
}
"""
class VideoAnnotator:
"""Annotates robot manipulation videos using local Qwen3-VL model on GPU"""
@@ -328,9 +337,9 @@ class VideoAnnotator:
# Add video duration to prompt
prompt_with_duration = f"""{self.prompt}
-CRITICAL - VIDEO DURATION:
-The video is {duration_str} long ({duration_seconds:.1f} seconds). Your annotations MUST cover the ENTIRE duration from 00:00 to {duration_str}.
-Do NOT stop annotating before the video ends. Make sure your last subtask ends at {duration_str} or very close to it."""
+# Video Duration:
+The video is {duration_str} long ({duration_seconds:.1f} seconds). Your total annotations MUST cover the ENTIRE duration from 00:00 to {duration_str}.
+Do NOT stop annotating before the video ends. Make sure your last subtask ends at {duration_str}."""
# Prepare messages for the model
messages = [
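A small sketch of how the `duration_str` and `duration_seconds` values interpolated above could be derived from an episode's frame count and FPS; the helper name and signature are assumptions, not taken from this diff.

```python
def format_duration(num_frames: int, fps: float) -> tuple[str, float]:
    """Return an MM:SS string plus the duration in seconds for prompt interpolation."""
    duration_seconds = num_frames / fps
    minutes, seconds = divmod(int(round(duration_seconds)), 60)
    duration_str = f"{minutes:02d}:{seconds:02d}"
    return duration_str, duration_seconds


# e.g. a 900-frame episode at 30 fps -> ("00:30", 30.0)
```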
@@ -771,27 +780,11 @@ Examples:
--video-key observation.images.top \\
--num-workers 4 \\
--push-to-hub
-# Parallel with specific GPU IDs (e.g., GPUs 0, 2, 3):
-python subtask_annotation.py \\
-    --repo-id pepijn223/mydataset \\
-    --subtasks "reach,grasp,lift,place" \\
-    --video-key observation.images.top \\
-    --num-workers 3 \\
-    --gpu-ids 0 2 3 \\
-    --push-to-hub
-# List available cameras:
-python subtask_annotation.py \\
-    --repo-id pepijn223/mydataset \\
-    --subtasks "reach,grasp" \\
-    --max-episodes 0
-Performance Tips:
+Performance remarks:
- Each worker loads one model instance on its assigned GPU
-- The 30B model requires ~60GB VRAM per GPU
-- Use --num-workers N for N GPUs to get N× speedup
- Episodes are distributed round-robin across workers
+- Use --num-workers N for N GPUs
"""
)
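The epilog above states that episodes are distributed round-robin across workers, with one model instance per assigned GPU. Below is a minimal sketch of that assignment scheme; the helper is hypothetical and simply assumes worker i maps to the i-th entry of `--gpu-ids` (or to GPU i when none are given).

```python
def assign_episodes_round_robin(
    episode_indices: list[int], num_workers: int, gpu_ids: list[int] | None = None
) -> dict[int, dict]:
    """Hypothetical helper: map each worker to a GPU and its round-robin share of episodes."""
    gpu_ids = gpu_ids if gpu_ids is not None else list(range(num_workers))
    assignments = {
        worker: {"gpu_id": gpu_ids[worker], "episodes": []} for worker in range(num_workers)
    }
    for i, episode in enumerate(episode_indices):
        assignments[i % num_workers]["episodes"].append(episode)
    return assignments


# assign_episodes_round_robin(list(range(7)), num_workers=3, gpu_ids=[0, 2, 3])
# -> worker 0 on GPU 0 gets [0, 3, 6], worker 1 on GPU 2 gets [1, 4], worker 2 on GPU 3 gets [2, 5]
```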
parser.add_argument(
@@ -885,10 +878,7 @@ Performance Tips:
args = parser.parse_args()
# Parse subtask list
subtask_list = [s.strip() for s in args.subtasks.split(",")]
# Parse dtype
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
@@ -906,11 +896,9 @@ Performance Tips:
border_style="cyan"
))
# Load dataset
console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
dataset = LeRobotDataset(args.repo_id, download_videos=True)
# Get FPS from dataset
fps = dataset.fps
console.print(f"[cyan]Dataset FPS: {fps}[/cyan]")
Second changed file: +191 -21
@@ -30,6 +30,7 @@ Example usage:
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Optional
@@ -38,12 +39,17 @@ import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
-from lerobot.policies.sarm.sarm_utils import pad_state_to_max_dim
+from lerobot.policies.sarm.sarm_utils import (
+    pad_state_to_max_dim,
+    compute_tau,
+    compute_cumulative_progress_batch,
+)
from lerobot.datasets.utils import load_stats
@@ -328,6 +334,110 @@ def run_inference(
return np.array(all_progress), np.array(all_stages)
def compute_ground_truth_progress(
dataset: LeRobotDataset,
episode_index: int,
temporal_proportions: dict[str, float],
subtask_names_ordered: list[str],
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
"""
Compute ground truth progress and stage labels for an episode using annotations.
Uses SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode
temporal_proportions: Dict mapping subtask name to proportion
subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)
Returns:
Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None) if no annotations
"""
# Load episode metadata
episodes_df = dataset.meta.episodes.to_pandas()
# Check if annotations exist
if "subtask_names" not in episodes_df.columns:
logger.warning("No subtask_names column found in episodes metadata")
return None, None
ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
logger.warning(f"No annotations found for episode {episode_index}")
return None, None
subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]
# Get episode boundaries
ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
num_frames = ep_end - ep_start
# Get temporal proportions as ordered list
temporal_proportions_list = [
temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
]
logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
logger.info(f"Subtask names in episode: {ep_subtask_names}")
logger.info(f"Subtask start frames: {subtask_start_frames}")
logger.info(f"Subtask end frames: {subtask_end_frames}")
logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")
# Compute ground truth for each frame
gt_progress = np.zeros(num_frames)
gt_stages = np.zeros(num_frames, dtype=np.int32)
for frame_rel in range(num_frames):
# Find which subtask this frame belongs to
found = False
for j, (name, start_frame, end_frame) in enumerate(zip(ep_subtask_names, subtask_start_frames, subtask_end_frames)):
if frame_rel >= start_frame and frame_rel <= end_frame:
# Found the subtask - get its global index
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else 0
# Compute τ_t using utility function
tau = compute_tau(frame_rel, start_frame, end_frame)
# Compute cumulative progress using utility function
progress = compute_cumulative_progress_batch(tau, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
found = True
break
if not found:
# Handle frames outside annotated subtasks
if frame_rel < subtask_start_frames[0]:
gt_progress[frame_rel] = 0.0
gt_stages[frame_rel] = 0
elif frame_rel > subtask_end_frames[-1]:
gt_progress[frame_rel] = 1.0
gt_stages[frame_rel] = len(subtask_names_ordered) - 1
else:
# Between subtasks - find previous subtask
for j in range(len(ep_subtask_names) - 1):
if frame_rel > subtask_end_frames[j] and frame_rel < subtask_start_frames[j + 1]:
name = ep_subtask_names[j]
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else j
progress = compute_cumulative_progress_batch(1.0, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
break
logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
return gt_progress, gt_stages
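As a quick numeric check of formula (2) from the docstring above, here is a toy three-subtask episode; the proportions and frame boundaries are invented, and the cumulative prior is summed inline rather than calling `compute_cumulative_progress_batch`.

```python
import numpy as np

# Hypothetical episode: three subtasks whose temporal proportions sum to 1.0.
proportions = [0.4, 0.3, 0.3]                 # alpha_1, alpha_2, alpha_3
boundaries = [(0, 40), (40, 70), (70, 100)]   # (s_k, e_k) frame ranges


def ground_truth_progress(t: int) -> float:
    """y_t = P_{k-1} + alpha_k * tau_t with tau_t = (t - s_k) / (e_k - s_k)."""
    for k, (s_k, e_k) in enumerate(boundaries):
        if s_k <= t <= e_k:
            tau = (t - s_k) / (e_k - s_k)
            prior = sum(proportions[:k])      # P_{k-1}
            return prior + proportions[k] * tau
    raise ValueError("frame outside annotated range")


# Frame 55 falls in subtask 2: tau = (55 - 40) / (70 - 40) = 0.5, so y = 0.4 + 0.3 * 0.5 = 0.55
assert np.isclose(ground_truth_progress(55), 0.55)
```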
def visualize_predictions(
frames: np.ndarray,
progress_predictions: np.ndarray,
@@ -337,10 +447,12 @@ def visualize_predictions(
num_sample_frames: int = 8,
figsize: tuple = (14, 8),
subtask_names: list[str] | None = None,
-    temporal_proportions: dict[str, float] | None = None
+    temporal_proportions: dict[str, float] | None = None,
+    ground_truth_progress: np.ndarray | None = None,
+    ground_truth_stages: np.ndarray | None = None,
):
"""
-    Create visualization of SARM predictions.
+    Create visualization of SARM predictions with optional ground truth comparison.
Args:
frames: Video frames (num_frames, H, W, C)
@@ -348,11 +460,12 @@ def visualize_predictions(
stage_predictions: Stage probabilities (num_frames, num_stages)
task_description: Task description
output_path: Path to save the figure
show_frames: Whether to include sample frames
num_sample_frames: Number of frames to show
figsize: Figure size (width, height)
subtask_names: Optional list of subtask names for labeling
temporal_proportions: Optional dict of temporal proportions for each subtask
ground_truth_progress: Optional ground truth progress array (num_frames,)
ground_truth_stages: Optional ground truth stage indices array (num_frames,)
"""
num_stages = stage_predictions.shape[1]
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
@@ -376,10 +489,16 @@ def visualize_predictions(
# Plot 1: Progress over time
ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')
# Plot ground truth if available
if ground_truth_progress is not None:
ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
linestyle='--', label='Ground Truth Progress')
ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')
ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax_progress.set_ylabel('Task Progress', fontsize=12)
-    ax_progress.set_title(f'SARM Task Progress & Stage Prediction\nTask: "{task_description}"',
-                          fontsize=14, fontweight='bold')
+    ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
ax_progress.grid(True, alpha=0.3)
ax_progress.set_ylim(-0.05, 1.1)
ax_progress.legend(loc='upper left')
@@ -391,6 +510,11 @@ def visualize_predictions(
f'Max Progress: {progress_predictions.max():.3f}\n'
f'Mean Progress: {progress_predictions.mean():.3f}'
)
if ground_truth_progress is not None:
mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
stats_text += f'\nMSE vs GT: {mse:.4f}'
stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'
ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
fontsize=10, verticalalignment='bottom', horizontalalignment='right',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
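The stats box now reports the MSE between predicted and ground-truth progress. A per-frame stage accuracy could be derived from the same arrays in a similar spirit; the following is only a sketch with toy data, not part of the commit.

```python
import numpy as np

# Toy stand-ins for the real arrays in visualize_predictions:
# stage_predictions has shape (num_frames, num_stages), ground_truth_stages has shape (num_frames,).
stage_predictions = np.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]])
ground_truth_stages = np.array([0, 1, 0])

predicted_stages = stage_predictions.argmax(axis=1)  # per-frame hard stage prediction
stage_accuracy = float(np.mean(predicted_stages == ground_truth_stages))
print(f"Per-frame stage accuracy vs GT: {stage_accuracy:.3f}")  # -> 0.667
```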
@@ -398,6 +522,21 @@ def visualize_predictions(
# Plot 2: Stage predictions (stacked area plot)
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
colors=stage_colors, alpha=0.8, labels=stage_labels)
# Plot ground truth stage as vertical bands or markers
if ground_truth_stages is not None:
# Find stage transition points in ground truth
stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
for change_idx in stage_changes:
ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)
# Add small markers at bottom showing GT stage
gt_stage_normalized = ground_truth_stages / max(num_stages - 1, 1)
ax_stages.scatter(frame_indices[::30], np.zeros(len(frame_indices[::30])) + 0.02,
c=[stage_colors[s] for s in ground_truth_stages[::30]],
s=20, marker='|', alpha=0.8, label='GT Stage Markers')
ax_stages.set_xlabel('Frame Index', fontsize=12)
ax_stages.set_ylabel('Stage Probability', fontsize=12)
ax_stages.set_ylim(0, 1)
@@ -540,20 +679,42 @@ def main():
subtask_names = model.config.subtask_names
logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
-# Try to load temporal proportions from model's dataset meta
-if hasattr(model, 'dataset_stats') and model.dataset_stats is not None:
-    if 'temporal_proportions' in model.dataset_stats:
-        temporal_proportions = model.dataset_stats['temporal_proportions']
-        logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
+# Try to load temporal proportions from model config
+if hasattr(model.config, 'temporal_proportions') and model.config.temporal_proportions is not None:
+    temporal_proportions = {
+        name: prop for name, prop in zip(model.config.subtask_names, model.config.temporal_proportions)
+    }
+    logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")
-# # Fallback: try to load from dataset meta
-# if temporal_proportions is None and subtask_names is not None:
-#     import json
-#     proportions_path = dataset.root / "meta" / "temporal_proportions.json"
-#     if proportions_path.exists():
-#         with open(proportions_path, 'r') as f:
-#             temporal_proportions = json.load(f)
-#         logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
+# Fallback: try to load from dataset meta
+if temporal_proportions is None:
+    proportions_path = dataset.root / "meta" / "temporal_proportions.json"
+    if proportions_path.exists():
+        with open(proportions_path, 'r') as f:
+            temporal_proportions = json.load(f)
+        logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
+        # Also extract subtask names from proportions if not already set
+        if subtask_names is None:
+            subtask_names = sorted(temporal_proportions.keys())
+            logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")
+# Compute ground truth progress if annotations are available
+ground_truth_progress = None
+ground_truth_stages = None
+if temporal_proportions is not None and subtask_names is not None:
+    logger.info("Attempting to compute ground truth progress from annotations...")
+    ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
+        dataset,
+        args.episode_index,
+        temporal_proportions,
+        subtask_names
+    )
+    if ground_truth_progress is None:
+        logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
+else:
+    logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")
output_dir = Path(args.output_dir)
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
@@ -567,11 +728,20 @@ def main():
num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize),
subtask_names=subtask_names,
-temporal_proportions=temporal_proportions
+temporal_proportions=temporal_proportions,
+ground_truth_progress=ground_truth_progress,
+ground_truth_stages=ground_truth_stages,
)
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
-np.savez(predictions_path, progress=progress_predictions, stages=stage_predictions)
+save_dict = {
+    'progress': progress_predictions,
+    'stages': stage_predictions
+}
+if ground_truth_progress is not None:
+    save_dict['gt_progress'] = ground_truth_progress
+    save_dict['gt_stages'] = ground_truth_stages
+np.savez(predictions_path, **save_dict)
logger.info(f"Saved predictions to {predictions_path}")
logger.info(f"\nVisualization: {output_path}")