This commit is contained in:
Pepijn
2025-11-18 15:28:40 +01:00
parent f688eb160b
commit 1ffdc6f49e
5 changed files with 624 additions and 56 deletions
+93 -7
View File
@@ -338,7 +338,9 @@ def visualize_predictions(
output_path: Path,
show_frames: bool = False,
num_sample_frames: int = 8,
figsize: tuple = (14, 8)
figsize: tuple = (14, 8),
subtask_names: list[str] | None = None,
temporal_proportions: dict[str, float] | None = None
):
"""
Create visualization of SARM predictions.
@@ -352,10 +354,18 @@ def visualize_predictions(
show_frames: Whether to include sample frames
num_sample_frames: Number of frames to show
figsize: Figure size (width, height)
subtask_names: Optional list of subtask names for labeling
temporal_proportions: Optional dict of temporal proportions for each subtask
"""
num_stages = stage_predictions.shape[1]
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
# Use subtask names if available, otherwise use generic labels
if subtask_names is not None and len(subtask_names) == num_stages:
stage_labels = subtask_names
else:
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
if show_frames:
# Create figure with progress plot, stage plot, and sample frames
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
@@ -398,12 +408,35 @@ def visualize_predictions(
# Plot 2: Stage predictions (stacked area plot)
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
colors=stage_colors, alpha=0.8, labels=[f'Stage {i+1}' for i in range(num_stages)])
colors=stage_colors, alpha=0.8, labels=stage_labels)
ax_stages.set_xlabel('Frame Index', fontsize=12)
ax_stages.set_ylabel('Stage Probability', fontsize=12)
ax_stages.set_ylim(0, 1)
ax_stages.grid(True, alpha=0.3)
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
# Adjust legend based on number of stages and label lengths
if num_stages <= 5:
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
else:
ax_stages.legend(loc='upper left', ncol=3, fontsize=7)
# Add vertical lines and labels for expected stage transitions (if temporal proportions available)
if temporal_proportions is not None and subtask_names is not None:
cumulative_progress = 0.0
for i, name in enumerate(stage_labels):
if name in temporal_proportions:
# Find approximate frame where this stage should end
stage_end_progress = cumulative_progress + temporal_proportions[name]
# Find frame index closest to this progress
progress_diffs = np.abs(progress_predictions - stage_end_progress)
stage_end_frame = np.argmin(progress_diffs)
# Draw vertical line
ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
cumulative_progress = stage_end_progress
# Plot 3: Sample frames (if requested)
if show_frames:
@@ -431,7 +464,13 @@ def visualize_predictions(
# Add frame number, progress, and stage
progress_val = progress_predictions[frame_idx]
stage_idx = np.argmax(stage_predictions[frame_idx])
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\nStage: {stage_idx+1}'
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
# Truncate long stage names for display
if len(stage_name) > 15:
stage_name = stage_name[:12] + '...'
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
# Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label,
@@ -495,6 +534,29 @@ def main():
# Run inference
progress_predictions, stage_predictions = run_inference(model, frames, states, task_description)
# Extract subtask names and temporal proportions from model config if available
subtask_names = None
temporal_proportions = None
if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None:
subtask_names = model.config.subtask_names
logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
# Try to load temporal proportions from model's dataset meta
if hasattr(model, 'dataset_stats') and model.dataset_stats is not None:
if 'temporal_proportions' in model.dataset_stats:
temporal_proportions = model.dataset_stats['temporal_proportions']
logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
# Fallback: try to load from dataset meta
if temporal_proportions is None and subtask_names is not None:
import json
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
if proportions_path.exists():
with open(proportions_path, 'r') as f:
temporal_proportions = json.load(f)
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
# Create visualization
output_dir = Path(args.output_dir)
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
@@ -507,7 +569,9 @@ def main():
output_path,
show_frames=args.show_frames,
num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize)
figsize=tuple(args.figsize),
subtask_names=subtask_names,
temporal_proportions=temporal_proportions
)
# Save predictions as numpy arrays
@@ -527,8 +591,30 @@ def main():
logger.info(f"Final Progress: {progress_predictions[-1]:.3f}")
logger.info(f"Max Progress: {progress_predictions.max():.3f}")
logger.info(f"Mean Progress: {progress_predictions.mean():.3f}")
logger.info(f"Most Common Stage: {np.argmax(np.sum(stage_predictions, axis=0)) + 1}")
logger.info(f"Visualization: {output_path}")
# Show most common stage with name if available
most_common_stage_idx = np.argmax(np.sum(stage_predictions, axis=0))
if subtask_names is not None and most_common_stage_idx < len(subtask_names):
most_common_stage_name = subtask_names[most_common_stage_idx]
logger.info(f"Most Common Stage: {most_common_stage_name} (Stage {most_common_stage_idx + 1})")
else:
logger.info(f"Most Common Stage: {most_common_stage_idx + 1}")
# Show subtask breakdown
if subtask_names is not None:
logger.info("\nSubtask Breakdown:")
total_frames = len(stage_predictions)
for i, name in enumerate(subtask_names):
# Calculate percentage of frames where this stage was dominant
dominant_frames = np.sum(np.argmax(stage_predictions, axis=1) == i)
percentage = (dominant_frames / total_frames) * 100
logger.info(f" {name}: {dominant_frames}/{total_frames} frames ({percentage:.1f}%)")
if temporal_proportions is not None and name in temporal_proportions:
expected_pct = temporal_proportions[name] * 100
logger.info(f" Expected: {expected_pct:.1f}% | Actual: {percentage:.1f}%")
logger.info(f"\nVisualization: {output_path}")
logger.info("="*60)