Fix visualization and change prompt
@@ -96,40 +96,49 @@ def create_sarm_prompt(subtask_list: list[str]) -> str:
    """
    subtask_str = "\n".join([f" - {name}" for name in subtask_list])

    return f"""You are an expert video annotator. Analyze this robot manipulation video and identify when each subtask occurs.

    return f"""# Role
You are an expert Robotics Vision System specializing in temporal action localization. Your task is to segment a video of a robot manipulation demonstration into a sequence of distinct, non-overlapping atomic actions.

WATCH THE ENTIRE VIDEO FIRST:

CRITICAL REQUIREMENTS:
1. You MUST use ONLY these EXACT subtask names (no variations, no other names):

# Input Data

## Allowed Subtask Vocabulary
You must strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones:
[
{subtask_str}
2. Identify the start and end timestamp for each subtask that occurs in the video
3. Subtasks should be in chronological order
4. Timestamps should be in MM:SS format (e.g., "00:15" for 15 seconds, "01:30" for 1 minute 30 seconds)
5. Subtasks should cover the entire demonstration without gaps
6. You MUST watch the COMPLETE video from start to finish before making ANY annotations or conclusions
7. Do NOT start annotating until you have seen the entire video length
8. Only after viewing the complete video should you identify the timestamps
9. EACH SUBTASK HAPPENS ONLY ONCE in the video - do not identify the same subtask multiple times
10. Note the exact times when each subtask starts and ends, but make sure to cover the ENTIRE video timeline.
]

FORMAT:
Return a JSON list of subtasks with their timestamps. Each subtask must have:
- "name": One of the exact names from the list above
- "timestamps": object with "start" and "end" fields (MM:SS format)

# Constraints & Logic
1. **Continuous Coverage:** The entire video duration (from 00:00 to the final second) must be accounted for. There can be no gaps between tasks.
2. **Boundary Logic:** The `end` timestamp of one task must be the exact `start` timestamp of the next task.
3. **Linear Progression:** The video represents a single successful demonstration. Each subtask from the vocabulary appears exactly once, in logical chronological order.
4. **Format:** Timestamps must be in "MM:SS" format.

Example structure:
{{

# Step-by-Step Analysis Process
1. **Visual grounding:** Look for the specific visual state changes that define the transition between tasks (e.g., gripper touching object, object lifting off table).
2. **Define Boundaries:** Determine the specific frame where the motion profile changes to fit the next subtask label.
3. **Fill Gaps:** If there is a pause between meaningful actions, append that time to the *preceding* task to ensure continuous coverage.

# Output Format
Provide the output in valid JSON format.
Structure:
{
"subtasks": [
  {{"name": "reach_to_object", "timestamps": {{"start": "00:00", "end": "00:05"}}}},
  {{"name": "grasp_object", "timestamps": {{"start": "00:05", "end": "00:08"}}}},
  ...
  {
    "name": "EXACT_NAME_FROM_LIST",
    "timestamps": {
      "start": "MM:SS",
      "end": "MM:SS"
    }
  },
  {
    "name": "EXACT_NAME_FROM_LIST",
    "timestamps": {
      "start": "MM:SS",
      "end": "MM:SS"
    }
  }
]
}}

Remember: Use ONLY the subtask names provided above, and cover the ENTIRE video timeline."""

}
"""

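# The prompts above request MM:SS timestamps in JSON, which downstream code
# must map back to frame indices. A minimal sketch of that conversion, assuming
# the JSON shape requested by the prompt; parse_mm_ss and annotations_to_frames
# are illustrative helpers, not part of this commit:
import json

def parse_mm_ss(ts: str) -> int:
    # "01:30" -> 90 seconds.
    minutes, seconds = ts.split(":")
    return int(minutes) * 60 + int(seconds)

def annotations_to_frames(response_text: str, fps: float) -> list[dict]:
    # Convert the model's JSON annotation output into frame ranges.
    subtasks = json.loads(response_text)["subtasks"]
    return [
        {
            "name": st["name"],
            "start_frame": round(parse_mm_ss(st["timestamps"]["start"]) * fps),
            "end_frame": round(parse_mm_ss(st["timestamps"]["end"]) * fps),
        }
        for st in subtasks
    ]
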
class VideoAnnotator:
    """Annotates robot manipulation videos using local Qwen3-VL model on GPU"""
@@ -328,9 +337,9 @@ class VideoAnnotator:
        # Add video duration to prompt
        prompt_with_duration = f"""{self.prompt}

CRITICAL - VIDEO DURATION:
The video is {duration_str} long ({duration_seconds:.1f} seconds). Your annotations MUST cover the ENTIRE duration from 00:00 to {duration_str}.
Do NOT stop annotating before the video ends. Make sure your last subtask ends at {duration_str} or very close to it."""

# Video Duration:
The video is {duration_str} long ({duration_seconds:.1f} seconds). Your total annotations MUST cover the ENTIRE duration from 00:00 to {duration_str}.
Do NOT stop annotating before the video ends. Make sure your last subtask ends at {duration_str}."""

        # Prepare messages for the model
        messages = [
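# The hunk uses duration_str and duration_seconds without showing how they are
# derived. A minimal sketch, assuming the duration comes from the episode's
# frame count and FPS (format_duration is an illustrative helper, not part of
# this commit):
def format_duration(num_frames: int, fps: float) -> tuple[str, float]:
    # Returns ("MM:SS", seconds) for embedding in the prompt.
    duration_seconds = num_frames / fps
    minutes, seconds = divmod(int(round(duration_seconds)), 60)
    return f"{minutes:02d}:{seconds:02d}", duration_seconds

duration_str, duration_seconds = format_duration(900, 30.0)  # -> ("00:30", 30.0)
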
@@ -771,27 +780,11 @@ Examples:
        --video-key observation.images.top \\
        --num-workers 4 \\
        --push-to-hub

    # Parallel with specific GPU IDs (e.g., GPUs 0, 2, 3):
    python subtask_annotation.py \\
        --repo-id pepijn223/mydataset \\
        --subtasks "reach,grasp,lift,place" \\
        --video-key observation.images.top \\
        --num-workers 3 \\
        --gpu-ids 0 2 3 \\
        --push-to-hub

    # List available cameras:
    python subtask_annotation.py \\
        --repo-id pepijn223/mydataset \\
        --subtasks "reach,grasp" \\
        --max-episodes 0

Performance Tips:
Performance remarks:
- Each worker loads one model instance on its assigned GPU
- The 30B model requires ~60GB VRAM per GPU
- Use --num-workers N for N GPUs to get N× speedup
- Episodes are distributed round-robin across workers
- Use --num-workers N for N GPUs
"""
)
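# "Episodes are distributed round-robin across workers" - a minimal sketch of
# that scheduling policy (round_robin is an illustrative helper, not the
# script's actual implementation):
def round_robin(episodes: list[int], num_workers: int) -> list[list[int]]:
    # Episode i goes to worker i % num_workers.
    buckets: list[list[int]] = [[] for _ in range(num_workers)]
    for i, ep in enumerate(episodes):
        buckets[i % num_workers].append(ep)
    return buckets

assert round_robin(list(range(7)), 3) == [[0, 3, 6], [1, 4], [2, 5]]
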
    parser.add_argument(
@@ -885,10 +878,7 @@ Performance Tips:

    args = parser.parse_args()

    # Parse subtask list
    subtask_list = [s.strip() for s in args.subtasks.split(",")]

    # Parse dtype
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
@@ -906,11 +896,9 @@ Performance Tips:
        border_style="cyan"
    ))

    # Load dataset
    console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
    dataset = LeRobotDataset(args.repo_id, download_videos=True)

    # Get FPS from dataset
    fps = dataset.fps
    console.print(f"[cyan]Dataset FPS: {fps}[/cyan]")

@@ -30,6 +30,7 @@ Example usage:
"""

import argparse
import json
import logging
from pathlib import Path
from typing import Optional
@@ -38,12 +39,17 @@ import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
from lerobot.policies.sarm.sarm_utils import pad_state_to_max_dim
from lerobot.policies.sarm.sarm_utils import (
    pad_state_to_max_dim,
    compute_tau,
    compute_cumulative_progress_batch,
)
from lerobot.datasets.utils import load_stats

@@ -328,6 +334,110 @@ def run_inference(
    return np.array(all_progress), np.array(all_stages)


def compute_ground_truth_progress(
    dataset: LeRobotDataset,
    episode_index: int,
    temporal_proportions: dict[str, float],
    subtask_names_ordered: list[str],
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
    """
    Compute ground truth progress and stage labels for an episode using annotations.

    Uses SARM paper formula (2):

        y_t = P_{k-1} + ᾱ_k × τ_t

    where:
    - τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
    - P_{k-1} is the cumulative prior (sum of previous subtask proportions)
    - ᾱ_k is the temporal proportion for subtask k

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode
        temporal_proportions: Dict mapping subtask name to proportion
        subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)

    Returns:
        Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None) if no annotations
    """
    # Load episode metadata
    episodes_df = dataset.meta.episodes.to_pandas()

    # Check if annotations exist
    if "subtask_names" not in episodes_df.columns:
        logger.warning("No subtask_names column found in episodes metadata")
        return None, None

    ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
    if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
        logger.warning(f"No annotations found for episode {episode_index}")
        return None, None

    subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
    subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]

    # Get episode boundaries
    ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
    ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
    num_frames = ep_end - ep_start

    # Get temporal proportions as an ordered list
    temporal_proportions_list = [
        temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
    ]

    logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
    logger.info(f"Subtask names in episode: {ep_subtask_names}")
    logger.info(f"Subtask start frames: {subtask_start_frames}")
    logger.info(f"Subtask end frames: {subtask_end_frames}")
    logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")

    # Compute ground truth for each frame
    gt_progress = np.zeros(num_frames)
    gt_stages = np.zeros(num_frames, dtype=np.int32)

    for frame_rel in range(num_frames):
        # Find which subtask this frame belongs to
        found = False
        for j, (name, start_frame, end_frame) in enumerate(zip(ep_subtask_names, subtask_start_frames, subtask_end_frames)):
            if frame_rel >= start_frame and frame_rel <= end_frame:
                # Found the subtask - get its global index
                stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else 0

                # Compute τ_t using utility function
                tau = compute_tau(frame_rel, start_frame, end_frame)

                # Compute cumulative progress using utility function
                progress = compute_cumulative_progress_batch(tau, stage_idx, temporal_proportions_list)

                gt_progress[frame_rel] = progress
                gt_stages[frame_rel] = stage_idx
                found = True
                break

        if not found:
            # Handle frames outside annotated subtasks
            if frame_rel < subtask_start_frames[0]:
                gt_progress[frame_rel] = 0.0
                gt_stages[frame_rel] = 0
            elif frame_rel > subtask_end_frames[-1]:
                gt_progress[frame_rel] = 1.0
                gt_stages[frame_rel] = len(subtask_names_ordered) - 1
            else:
                # Between subtasks - find previous subtask
                for j in range(len(ep_subtask_names) - 1):
                    if frame_rel > subtask_end_frames[j] and frame_rel < subtask_start_frames[j + 1]:
                        name = ep_subtask_names[j]
                        stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else j
                        progress = compute_cumulative_progress_batch(1.0, stage_idx, temporal_proportions_list)
                        gt_progress[frame_rel] = progress
                        gt_stages[frame_rel] = stage_idx
                        break

    logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
    return gt_progress, gt_stages
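
# Worked check of formula (2) above, assuming three subtasks with temporal
# proportions [0.3, 0.3, 0.4]. The helpers below are local stand-ins that
# mirror what compute_tau and compute_cumulative_progress_batch are used for
# here; they are illustrative, not the sarm_utils implementations.
def _tau(frame: int, start: int, end: int) -> float:
    # Within-subtask progress tau_t = (t - s_k) / (e_k - s_k), clamped to [0, 1].
    return 0.0 if end <= start else min(max((frame - start) / (end - start), 0.0), 1.0)

def _cumulative_progress(tau: float, stage_idx: int, proportions: list[float]) -> float:
    # y_t = P_{k-1} + alpha_k * tau_t
    return sum(proportions[:stage_idx]) + proportions[stage_idx] * tau

proportions = [0.3, 0.3, 0.4]
tau = _tau(15, 10, 20)  # frame 15 in a subtask spanning frames 10..20 -> 0.5
# Second subtask (stage_idx=1): y = 0.3 + 0.3 * 0.5 = 0.45
assert abs(_cumulative_progress(tau, 1, proportions) - 0.45) < 1e-9
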
def visualize_predictions(
    frames: np.ndarray,
    progress_predictions: np.ndarray,
@@ -337,10 +447,12 @@ def visualize_predictions(
    num_sample_frames: int = 8,
    figsize: tuple = (14, 8),
    subtask_names: list[str] | None = None,
    temporal_proportions: dict[str, float] | None = None
    temporal_proportions: dict[str, float] | None = None,
    ground_truth_progress: np.ndarray | None = None,
    ground_truth_stages: np.ndarray | None = None,
):
    """
    Create visualization of SARM predictions.
    Create visualization of SARM predictions with optional ground truth comparison.

    Args:
        frames: Video frames (num_frames, H, W, C)
@@ -348,11 +460,12 @@ def visualize_predictions(
        stage_predictions: Stage probabilities (num_frames, num_stages)
        task_description: Task description
        output_path: Path to save the figure
        show_frames: Whether to include sample frames
        num_sample_frames: Number of frames to show
        figsize: Figure size (width, height)
        subtask_names: Optional list of subtask names for labeling
        temporal_proportions: Optional dict of temporal proportions for each subtask
        ground_truth_progress: Optional ground truth progress array (num_frames,)
        ground_truth_stages: Optional ground truth stage indices array (num_frames,)
    """
    num_stages = stage_predictions.shape[1]
    stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
@@ -376,10 +489,16 @@ def visualize_predictions(
    # Plot 1: Progress over time
    ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
    ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')

    # Plot ground truth if available
    if ground_truth_progress is not None:
        ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
                         linestyle='--', label='Ground Truth Progress')
        ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')

    ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
    ax_progress.set_ylabel('Task Progress', fontsize=12)
    ax_progress.set_title(f'SARM Task Progress & Stage Prediction\nTask: "{task_description}"',
                          fontsize=14, fontweight='bold')
    ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
    ax_progress.grid(True, alpha=0.3)
    ax_progress.set_ylim(-0.05, 1.1)
    ax_progress.legend(loc='upper left')
@@ -391,6 +510,11 @@ def visualize_predictions(
        f'Max Progress: {progress_predictions.max():.3f}\n'
        f'Mean Progress: {progress_predictions.mean():.3f}'
    )
    if ground_truth_progress is not None:
        mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
        stats_text += f'\nMSE vs GT: {mse:.4f}'
        stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'

    ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
                     fontsize=10, verticalalignment='bottom', horizontalalignment='right',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
@@ -398,6 +522,21 @@ def visualize_predictions(
    # Plot 2: Stage predictions (stacked area plot)
    ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
                        colors=stage_colors, alpha=0.8, labels=stage_labels)

    # Plot ground truth stage as vertical bands or markers
    if ground_truth_stages is not None:
        # Find stage transition points in ground truth
        stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
        for change_idx in stage_changes:
            ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
            ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)

        # Add small markers at bottom showing GT stage
        gt_stage_normalized = ground_truth_stages / max(num_stages - 1, 1)
        ax_stages.scatter(frame_indices[::30], np.zeros(len(frame_indices[::30])) + 0.02,
                          c=[stage_colors[s] for s in ground_truth_stages[::30]],
                          s=20, marker='|', alpha=0.8, label='GT Stage Markers')

    ax_stages.set_xlabel('Frame Index', fontsize=12)
    ax_stages.set_ylabel('Stage Probability', fontsize=12)
    ax_stages.set_ylim(0, 1)
@@ -540,20 +679,42 @@ def main():
    subtask_names = model.config.subtask_names
    logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")

    # Try to load temporal proportions from model's dataset meta
    if hasattr(model, 'dataset_stats') and model.dataset_stats is not None:
        if 'temporal_proportions' in model.dataset_stats:
            temporal_proportions = model.dataset_stats['temporal_proportions']
            logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
    # Try to load temporal proportions from model config
    if hasattr(model.config, 'temporal_proportions') and model.config.temporal_proportions is not None:
        temporal_proportions = {
            name: prop for name, prop in zip(model.config.subtask_names, model.config.temporal_proportions)
        }
        logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")

    # # Fallback: try to load from dataset meta
    # if temporal_proportions is None and subtask_names is not None:
    #     import json
    #     proportions_path = dataset.root / "meta" / "temporal_proportions.json"
    #     if proportions_path.exists():
    #         with open(proportions_path, 'r') as f:
    #             temporal_proportions = json.load(f)
    #         logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
    # Fallback: try to load from dataset meta
    if temporal_proportions is None:
        proportions_path = dataset.root / "meta" / "temporal_proportions.json"
        if proportions_path.exists():
            with open(proportions_path, 'r') as f:
                temporal_proportions = json.load(f)
            logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")

            # Also extract subtask names from proportions if not already set
            if subtask_names is None:
                subtask_names = sorted(temporal_proportions.keys())
                logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")

    # Compute ground truth progress if annotations are available
    ground_truth_progress = None
    ground_truth_stages = None

    if temporal_proportions is not None and subtask_names is not None:
        logger.info("Attempting to compute ground truth progress from annotations...")
        ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
            dataset,
            args.episode_index,
            temporal_proportions,
            subtask_names
        )
        if ground_truth_progress is None:
            logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
    else:
        logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")

    output_dir = Path(args.output_dir)
    output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
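# For reference, a plausible meta/temporal_proportions.json payload, consistent
# with the flat name-to-proportion mapping the fallback loader above expects.
# The subtask names and values here are made up for illustration:
import json

example_proportions = {
    "reach_to_object": 0.25,
    "grasp_object": 0.15,
    "lift_object": 0.25,
    "place_object": 0.35,
}  # expected to sum to ~1.0 over the full episode
print(json.dumps(example_proportions, indent=2))
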
@@ -567,11 +728,20 @@ def main():
        num_sample_frames=args.num_sample_frames,
        figsize=tuple(args.figsize),
        subtask_names=subtask_names,
        temporal_proportions=temporal_proportions
        temporal_proportions=temporal_proportions,
        ground_truth_progress=ground_truth_progress,
        ground_truth_stages=ground_truth_stages,
    )

    predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
    np.savez(predictions_path, progress=progress_predictions, stages=stage_predictions)
    save_dict = {
        'progress': progress_predictions,
        'stages': stage_predictions
    }
    if ground_truth_progress is not None:
        save_dict['gt_progress'] = ground_truth_progress
        save_dict['gt_stages'] = ground_truth_stages
    np.savez(predictions_path, **save_dict)
    logger.info(f"Saved predictions to {predictions_path}")
    logger.info(f"\nVisualization: {output_path}")
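
# The saved .npz can be reloaded later for offline comparison. A minimal
# sketch, assuming the keys written above (the file path is illustrative):
import numpy as np

data = np.load("outputs/predictions_ep0.npz")
progress, stages = data["progress"], data["stages"]
if "gt_progress" in data.files:
    mse = float(np.mean((progress - data["gt_progress"]) ** 2))
    print(f"MSE vs ground truth: {mse:.4f}")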