#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Inference script for SARM (Stage-Aware Reward Model).

This script loads a trained SARM model and runs inference on a dataset episode,
generating visualizations of the predicted task stages and progress over time.

Example usage:
    python scripts/visualize_sarm_predictions.py \
        --model-id username/sarm-model \
        --dataset-repo lerobot/aloha_sim_insertion_human \
        --episode-index 0 \
        --output-dir outputs/sarm_viz \
        --task-description "insert the peg into the socket"
"""

import argparse
import json
import logging
from pathlib import Path
from typing import Optional

import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import load_stats
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
from lerobot.policies.sarm.sarm_utils import (
    compute_cumulative_progress_batch,
    compute_tau,
    pad_state_to_max_dim,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    """Parse the command-line arguments of the visualization script.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Run SARM inference and visualize predictions")

    # Model arguments
    parser.add_argument(
        "--model-id",
        type=str,
        required=True,
        help="HuggingFace model ID or local path to trained SARM model"
    )

    # Dataset arguments
    parser.add_argument(
        "--dataset-repo",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)"
    )
    parser.add_argument(
        "--episode-index",
        type=int,
        default=0,
        help="Index of the episode to visualize (default: 0)"
    )
    parser.add_argument(
        "--task-description",
        type=str,
        default="perform the task",
        help="Task description for the reward model (default: 'perform the task')"
    )

    # Output arguments
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/sarm_inference",
        help="Directory to save visualization outputs (default: outputs/sarm_inference)"
    )
    parser.add_argument(
        "--image-key",
        type=str,
        default=None,
        help="Key for images in dataset (e.g., observation.images.image). If not specified, uses model config's image_key"
    )
    parser.add_argument(
        "--state-key",
        type=str,
        default=None,
        help="Key for joint states in dataset. If None, auto-detects from dataset"
    )

    # Visualization options
    # NOTE(review): --show-frames is parsed but nothing in this file reads
    # args.show_frames; the sample-frame strip is always drawn.
    parser.add_argument(
        "--show-frames",
        action="store_true",
        help="Include sample frames in the visualization"
    )
    parser.add_argument(
        "--num-sample-frames",
        type=int,
        default=8,
        help="Number of sample frames to show (default: 8)"
    )
    parser.add_argument(
        "--figsize",
        type=int,
        nargs=2,
        default=[14, 8],
        help="Figure size as width height (default: 14 8)"
    )

    # Device
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run inference on (cuda/cpu, default: auto-detect)"
    )

    return parser.parse_args()


def load_episode_data(
    dataset: LeRobotDataset,
    episode_index: int,
    image_key: str,
    state_key: str | None = None
) -> tuple[np.ndarray, np.ndarray | None, int, int, str | None]:
    """
    Load all frames and states from a specific episode.

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode to load
        image_key: Key for accessing images in the dataset
        state_key: Key for accessing joint states (auto-detected if None)

    Returns:
        Tuple of (frames, states, start_index, end_index, task_description).
        ``frames`` is a uint8 array of channel-last images; ``states`` is None
        when no state key is available; ``task_description`` is None when the
        first item of the episode carries no "task" entry.

    Raises:
        ValueError: If the episode contains no frames.
    """
    # Get episode boundaries (coerced to plain ints so range() and logging behave)
    episode_data = dataset.meta.episodes
    start_idx = int(episode_data["dataset_from_index"][episode_index])
    end_idx = int(episode_data["dataset_to_index"][episode_index])

    logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")

    if end_idx <= start_idx:
        raise ValueError(f"Episode {episode_index} contains no frames (range {start_idx}-{end_idx})")

    # The first item is used both for state-key auto-detection and for the task
    # description, so fetch (and decode) it only once.
    first_item = dataset[start_idx]

    # Auto-detect state key if not provided
    if state_key is None:
        state_keys = [k for k in first_item.keys() if 'state' in k.lower() or 'qpos' in k.lower()]
        if state_keys:
            state_key = state_keys[0]
            logger.info(f"Auto-detected state key: {state_key}")

    # Get task description from the dataset if available
    task_description = None
    if "task" in first_item:
        task_description = first_item["task"]
        logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")

    # Load all frames and states from the episode
    frames = []
    states = []
    for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
        item = dataset[idx]

        img = item[image_key]
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()

        # Channel-first (C, H, W) images are moved to channel-last (H, W, C)
        if img.ndim == 3 and img.shape[0] in (1, 3):
            img = np.transpose(img, (1, 2, 0))

        # Float images in [0, 1] are rescaled; anything else is cast as-is
        if img.dtype != np.uint8:
            if img.max() <= 1.0:
                img = (img * 255).astype(np.uint8)
            else:
                img = img.astype(np.uint8)
        frames.append(img)

        # Get state if available
        if state_key and state_key in item:
            state = item[state_key]
            if isinstance(state, torch.Tensor):
                state = state.cpu().numpy()
            states.append(state)

    frames = np.array(frames)
    states = np.array(states) if states else None

    logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
    if states is not None:
        logger.info(f"Loaded states with shape {states.shape}")

    return frames, states, start_idx, end_idx, task_description


@torch.no_grad()
def run_inference(
    model: SARMRewardModel,
    frames: np.ndarray,
    states: Optional[np.ndarray],
    task_description: str,
    dataset_stats: dict | None = None,
    state_key: str = "observation.state",
    batch_size: int = 32
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run SARM inference on video frames and joint states.

    For every frame ``t`` a temporal window is built by adding each entry of
    ``model.config.observation_delta_indices`` to ``t`` and clamping the result
    to the valid frame range of the episode. The prediction for frame ``t`` is
    read from position 5 of the model output for that window.

    NOTE(review): the original inline comments describe the window as
    ``[initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g]`` (current frame at
    position 5). Confirm that this matches ``observation_delta_indices`` of the
    loaded config.

    Args:
        model: SARM model
        frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
        states: Joint states (num_frames, state_dim), or None
        task_description: Task description text
        dataset_stats: Dataset statistics for state normalization (same as training)
        state_key: Key for state in dataset_stats
        batch_size: Batch size for processing slices

    Returns:
        Tuple of (progress_predictions, stage_predictions)
        - progress_predictions: (num_frames,)
        - stage_predictions: (num_frames, num_stages)
    """
    logger.info("Encoding video frames with CLIP...")
    video_embeddings = model.encode_images(frames)

    logger.info("Encoding task description with CLIP...")
    text_embedding = model.encode_text(task_description)

    logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")

    # as_tensor accepts both arrays and tensors without the copy-construct warning
    video_embeddings = torch.as_tensor(video_embeddings, dtype=torch.float32)
    text_embedding = torch.as_tensor(text_embedding, dtype=torch.float32)

    if states is not None:
        state_embeddings = torch.as_tensor(states, dtype=torch.float32)

        # Normalize states using dataset stats (same as training processor)
        if dataset_stats is not None and state_key in dataset_stats:
            mean = torch.tensor(dataset_stats[state_key]["mean"], dtype=torch.float32)
            std = torch.tensor(dataset_stats[state_key]["std"], dtype=torch.float32)
            state_embeddings = (state_embeddings - mean) / (std + 1e-8)
            logger.info(f"✓ Applied MEAN_STD normalization to states using {state_key}")
        else:
            logger.warning("⚠ No dataset_stats provided - states not normalized (may differ from training)")
    else:
        state_embeddings = None

    # Build every window at once: row t holds the frame indices t + delta,
    # clamped to [0, last_valid].
    total_frames = len(video_embeddings)
    last_valid = total_frames - 1
    deltas = torch.as_tensor(list(model.config.observation_delta_indices), dtype=torch.long)
    window_indices = (torch.arange(total_frames).unsqueeze(1) + deltas.unsqueeze(0)).clamp(0, last_valid)

    # (num_frames, window_len, embed_dim)
    video_slices = video_embeddings[window_indices.to(video_embeddings.device)]

    if state_embeddings is not None:
        # (num_frames, window_len, state_dim)
        state_slices = state_embeddings[window_indices.to(state_embeddings.device)]
        # Pad states to max_state_dim (same as training processor)
        state_slices = pad_state_to_max_dim(state_slices, model.config.max_state_dim)
    else:
        state_slices = None

    logger.info("Running SARM inference on all slices...")

    # Position inside each window whose prediction is reported for frame t
    # (see NOTE(review) in the docstring).
    current_frame_idx = 5

    all_progress = []
    all_stages = []
    for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
        batch_video = video_slices[i:i + batch_size].to(model.device)
        batch_states = state_slices[i:i + batch_size].to(model.device) if state_slices is not None else None
        batch_size_actual = batch_video.shape[0]

        # Replicate text embedding for batch
        batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)

        # Stage logits are returned by the model but not needed here
        _, stage_probs, progress_preds = model.sarm_transformer(
            batch_video, batch_text, batch_states
        )

        batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy()
        batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy()

        all_progress.extend(batch_progress)
        all_stages.extend(batch_stages)

    return np.array(all_progress), np.array(all_stages)


def compute_ground_truth_progress(
    dataset: LeRobotDataset,
    episode_index: int,
    temporal_proportions: dict[str, float],
    subtask_names_ordered: list[str],
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
    """
    Compute ground truth progress and stage labels for an episode using annotations.

    Uses SARM Paper Formula (2): y_t = P_{k-1} + ᾱ_k × τ_t
    where:
    - τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
    - P_{k-1} is cumulative prior (sum of previous subtask proportions)
    - ᾱ_k is the temporal proportion for subtask k

    Frames before the first annotated subtask get progress 0 / stage 0, frames
    after the last one get progress 1 / the last stage, and frames in a gap
    between two subtasks take the completed progress of the preceding subtask.

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode
        temporal_proportions: Dict mapping subtask name to proportion
        subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)

    Returns:
        Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None) if no annotations
    """
    # Load episode metadata
    episodes_df = dataset.meta.episodes.to_pandas()

    # Check if annotations exist
    if "subtask_names" not in episodes_df.columns:
        logger.warning("No subtask_names column found in episodes metadata")
        return None, None
    if "subtask_start_frames" not in episodes_df.columns or "subtask_end_frames" not in episodes_df.columns:
        logger.warning("No subtask_start_frames/subtask_end_frames columns found in episodes metadata")
        return None, None

    ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
    if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
        logger.warning(f"No annotations found for episode {episode_index}")
        return None, None
    if len(ep_subtask_names) == 0:
        # An empty annotation list would otherwise crash on subtask_start_frames[0]
        logger.warning(f"No annotations found for episode {episode_index}")
        return None, None

    subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
    subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]

    # Get episode boundaries
    ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
    ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
    num_frames = int(ep_end - ep_start)

    # Get temporal proportions as ordered list
    temporal_proportions_list = [
        temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
    ]

    logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
    logger.info(f"Subtask names in episode: {ep_subtask_names}")
    logger.info(f"Subtask start frames: {subtask_start_frames}")
    logger.info(f"Subtask end frames: {subtask_end_frames}")
    logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")

    # Compute ground truth for each frame
    gt_progress = np.zeros(num_frames)
    gt_stages = np.zeros(num_frames, dtype=np.int32)

    annotated_spans = list(zip(ep_subtask_names, subtask_start_frames, subtask_end_frames))

    for frame_rel in range(num_frames):
        # Find which subtask this frame belongs to
        for name, start_frame, end_frame in annotated_spans:
            if start_frame <= frame_rel <= end_frame:
                # Found the subtask - get its global index
                stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else 0

                # Compute τ_t using utility function
                tau = compute_tau(frame_rel, start_frame, end_frame)

                # Compute cumulative progress using utility function
                gt_progress[frame_rel] = compute_cumulative_progress_batch(tau, stage_idx, temporal_proportions_list)
                gt_stages[frame_rel] = stage_idx
                break
        else:
            # Handle frames outside annotated subtasks
            if frame_rel < subtask_start_frames[0]:
                gt_progress[frame_rel] = 0.0
                gt_stages[frame_rel] = 0
            elif frame_rel > subtask_end_frames[-1]:
                gt_progress[frame_rel] = 1.0
                gt_stages[frame_rel] = len(subtask_names_ordered) - 1
            else:
                # Between subtasks - find previous subtask
                for j in range(len(ep_subtask_names) - 1):
                    if subtask_end_frames[j] < frame_rel < subtask_start_frames[j + 1]:
                        name = ep_subtask_names[j]
                        stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else j
                        gt_progress[frame_rel] = compute_cumulative_progress_batch(
                            1.0, stage_idx, temporal_proportions_list
                        )
                        gt_stages[frame_rel] = stage_idx
                        break

    if num_frames > 0:
        logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
    return gt_progress, gt_stages


def visualize_predictions(
    frames: np.ndarray,
    progress_predictions: np.ndarray,
    stage_predictions: np.ndarray,
    task_description: str,
    output_path: Path,
    num_sample_frames: int = 8,
    figsize: tuple = (14, 8),
    subtask_names: list[str] | None = None,
    temporal_proportions: dict[str, float] | None = None,
    ground_truth_progress: np.ndarray | None = None,
    ground_truth_stages: np.ndarray | None = None,
):
    """
    Create visualization of SARM predictions with optional ground truth comparison.

    The figure has three stacked panels: predicted (and optional ground truth)
    progress, a stacked-area plot of stage probabilities, and a strip of sample
    frames. The figure is written to ``output_path`` (parent directories are
    created) and then closed.

    Args:
        frames: Video frames (num_frames, H, W, C)
        progress_predictions: Progress predictions (num_frames,)
        stage_predictions: Stage probabilities (num_frames, num_stages)
        task_description: Task description
        output_path: Path to save the figure
        num_sample_frames: Number of frames to show (the strip is skipped when <= 0)
        figsize: Figure size (width, height)
        subtask_names: Optional list of subtask names for labeling
        temporal_proportions: Optional dict of temporal proportions for each subtask
        ground_truth_progress: Optional ground truth progress array (num_frames,)
        ground_truth_stages: Optional ground truth stage indices array (num_frames,)
    """
    num_stages = stage_predictions.shape[1]
    stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))

    # Use subtask names if available, otherwise use generic labels
    if subtask_names is not None and len(subtask_names) == num_stages:
        stage_labels = subtask_names
    else:
        stage_labels = [f'Stage {i+1}' for i in range(num_stages)]

    # Create figure with progress plot, stage plot, and sample frames
    fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
    gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
    ax_progress = fig.add_subplot(gs[0])
    ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
    ax_frames = fig.add_subplot(gs[2])

    frame_indices = np.arange(len(progress_predictions))

    # Plot 1: Progress over time
    ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
    ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')

    # Plot ground truth if available
    if ground_truth_progress is not None:
        ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
                         linestyle='--', label='Ground Truth Progress')
        ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')

    ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
    ax_progress.set_ylabel('Task Progress', fontsize=12)
    ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
    ax_progress.grid(True, alpha=0.3)
    ax_progress.set_ylim(-0.05, 1.1)
    ax_progress.legend(loc='upper left')

    # Add statistics box
    stats_text = (
        f'Frames: {len(progress_predictions)}\n'
        f'Final Progress: {progress_predictions[-1]:.3f}\n'
        f'Max Progress: {progress_predictions.max():.3f}\n'
        f'Mean Progress: {progress_predictions.mean():.3f}'
    )
    if ground_truth_progress is not None:
        mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
        stats_text += f'\nMSE vs GT: {mse:.4f}'
        stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'

    ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
                     fontsize=10, verticalalignment='bottom', horizontalalignment='right',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Plot 2: Stage predictions (stacked area plot)
    ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
                        colors=stage_colors, alpha=0.8, labels=stage_labels)

    # Plot ground truth stage as vertical bands or markers
    if ground_truth_stages is not None:
        # Find stage transition points in ground truth
        stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
        for change_idx in stage_changes:
            ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
            ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)

        # Add small markers at bottom showing GT stage (one every marker_stride frames)
        marker_stride = 30
        marker_frames = frame_indices[::marker_stride]
        # Clip so that an annotated stage index beyond the model's stage count
        # cannot index past the colour table.
        marker_stages = np.clip(ground_truth_stages[::marker_stride], 0, num_stages - 1)
        ax_stages.scatter(marker_frames, np.full(len(marker_frames), 0.02),
                          c=stage_colors[marker_stages], s=20, marker='|', alpha=0.8,
                          label='GT Stage Markers')

    ax_stages.set_xlabel('Frame Index', fontsize=12)
    ax_stages.set_ylabel('Stage Probability', fontsize=12)
    ax_stages.set_ylim(0, 1)
    ax_stages.grid(True, alpha=0.3)

    # Adjust legend based on number of stages and label lengths
    if num_stages <= 5:
        ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
    else:
        ax_stages.legend(loc='upper left', ncol=3, fontsize=7)

    # Add vertical lines for expected stage transitions (if temporal proportions available)
    if temporal_proportions is not None and subtask_names is not None:
        cumulative_progress = 0.0
        for name in stage_labels:
            if name in temporal_proportions:
                # Progress value at which this stage is expected to end
                stage_end_progress = cumulative_progress + temporal_proportions[name]

                # Frame whose predicted progress is closest to that value
                stage_end_frame = np.argmin(np.abs(progress_predictions - stage_end_progress))

                ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
                ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)

                cumulative_progress = stage_end_progress

    # Plot 3: Sample frames
    ax_frames.axis('off')
    if num_sample_frames > 0:
        frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)

        # Create grid for frames: all samples side by side in one image
        frame_height = frames[0].shape[0]
        frame_width = frames[0].shape[1]
        combined_width = frame_width * num_sample_frames
        combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)

        for i, frame_idx in enumerate(frame_indices_to_show):
            frame = frames[frame_idx]
            # Single-channel frames are replicated to three channels
            if frame.shape[-1] == 1:
                frame = np.repeat(frame, 3, axis=-1)

            # Add frame to combined image
            x_start = i * frame_width
            x_end = (i + 1) * frame_width
            combined_image[:, x_start:x_end] = frame

            # Add frame number, progress, and stage
            progress_val = progress_predictions[frame_idx]
            stage_idx = np.argmax(stage_predictions[frame_idx])
            stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'

            # Truncate long stage names for display
            if len(stage_name) > 15:
                stage_name = stage_name[:12] + '...'

            label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'

            # Draw label on image
            ax_frames.text(x_start + frame_width / 2, -10, label,
                           ha='center', va='top', fontsize=7,
                           bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

        ax_frames.imshow(combined_image)
        ax_frames.set_title('Sample Frames', fontsize=12, pad=20)

    fig.tight_layout()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    logger.info(f"Saved visualization to {output_path}")

    # Close this specific figure rather than whatever pyplot considers current
    plt.close(fig)


def main():
    """Load model and dataset, run inference on one episode, and save the plot and raw predictions."""
    args = parse_args()

    # Setup device
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    logger.info(f"Using device: {device}")

    # Load model
    logger.info(f"Loading SARM model from {args.model_id}...")
    model = SARMRewardModel.from_pretrained(args.model_id)
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully")

    # Load dataset
    logger.info(f"Loading dataset {args.dataset_repo}...")
    dataset = LeRobotDataset(args.dataset_repo)
    num_episodes = len(dataset.meta.episodes)
    logger.info(f"Dataset loaded: {num_episodes} episodes, {len(dataset)} frames")

    # Validate episode index (negative values are rejected as well)
    if not 0 <= args.episode_index < num_episodes:
        raise ValueError(
            f"Episode index {args.episode_index} out of range. "
            f"Dataset has {num_episodes} episodes."
        )

    image_key = args.image_key if args.image_key is not None else model.config.image_key
    state_key = args.state_key if args.state_key is not None else model.config.state_key
    logger.info(f"Using image key: {image_key}")
    logger.info(f"Using state key: {state_key}")

    # Load dataset stats for state normalization (same as training)
    dataset_stats = load_stats(dataset.root)
    if dataset_stats:
        logger.info(f"✓ Loaded dataset stats from {dataset.root}")
    else:
        logger.warning("⚠ Could not load dataset stats - states will not be normalized")

    # Load episode data (the episode boundaries are not needed here)
    frames, states, _, _, dataset_task = load_episode_data(
        dataset, args.episode_index, image_key, state_key
    )

    # Use task description from dataset if available, otherwise use command-line argument
    task_description = dataset_task if dataset_task is not None else args.task_description
    logger.info(f"Using task description: '{task_description}'")

    # Run inference
    progress_predictions, stage_predictions = run_inference(
        model, frames, states, task_description,
        dataset_stats=dataset_stats, state_key=state_key
    )

    # Extract subtask names and temporal proportions from model config if available
    subtask_names = getattr(model.config, 'subtask_names', None)
    temporal_proportions = None

    if subtask_names is not None:
        logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")

    # Try to load temporal proportions from model config
    config_proportions = getattr(model.config, 'temporal_proportions', None)
    if isinstance(config_proportions, dict):
        temporal_proportions = dict(config_proportions)
    elif config_proportions is not None and subtask_names is not None:
        # Pairing requires the names; without them fall through to the dataset file
        temporal_proportions = dict(zip(subtask_names, config_proportions))
    if temporal_proportions is not None:
        logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")

    # Fallback: try to load from dataset meta
    if temporal_proportions is None:
        proportions_path = dataset.root / "meta" / "temporal_proportions.json"
        if proportions_path.exists():
            with open(proportions_path, 'r', encoding='utf-8') as f:
                temporal_proportions = json.load(f)
            logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")

            # Also extract subtask names from proportions if not already set
            if subtask_names is None:
                subtask_names = sorted(temporal_proportions.keys())
                logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")

    # Compute ground truth progress if annotations are available
    ground_truth_progress = None
    ground_truth_stages = None

    if temporal_proportions is not None and subtask_names is not None:
        logger.info("Attempting to compute ground truth progress from annotations...")
        ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
            dataset, args.episode_index, temporal_proportions, subtask_names
        )
        if ground_truth_progress is None:
            logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
    else:
        logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")

    # Create the output directory here so saving the .npz below does not depend
    # on visualize_predictions having created it.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"

    visualize_predictions(
        frames,
        progress_predictions,
        stage_predictions,
        task_description,
        output_path,
        num_sample_frames=args.num_sample_frames,
        figsize=tuple(args.figsize),
        subtask_names=subtask_names,
        temporal_proportions=temporal_proportions,
        ground_truth_progress=ground_truth_progress,
        ground_truth_stages=ground_truth_stages,
    )

    predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
    save_dict = {
        'progress': progress_predictions,
        'stages': stage_predictions
    }
    if ground_truth_progress is not None:
        save_dict['gt_progress'] = ground_truth_progress
        save_dict['gt_stages'] = ground_truth_stages
    np.savez(predictions_path, **save_dict)
    logger.info(f"Saved predictions to {predictions_path}")

    logger.info(f"\nVisualization: {output_path}")


if __name__ == "__main__":
    main()