diff --git a/scripts/visualize_sarm_predictions.py b/scripts/visualize_sarm_predictions.py index 725f76a53..a929cf233 100644 --- a/scripts/visualize_sarm_predictions.py +++ b/scripts/visualize_sarm_predictions.py @@ -43,6 +43,8 @@ from tqdm import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.sarm.modeling_sarm import SARMRewardModel +from lerobot.policies.sarm.sarm_utils import pad_state_to_max_dim +from lerobot.datasets.utils import load_stats logging.basicConfig(level=logging.INFO) @@ -219,6 +221,8 @@ def run_inference( frames: np.ndarray, states: Optional[np.ndarray], task_description: str, + dataset_stats: dict | None = None, + state_key: str = "observation.state", batch_size: int = 32 ) -> tuple[np.ndarray, np.ndarray]: """ @@ -234,6 +238,8 @@ def run_inference( frames: Video frames (num_frames, H, W, C) - all frames from ONE episode states: Joint states (num_frames, state_dim) task_description: Task description text + dataset_stats: Dataset statistics for state normalization (same as training) + state_key: Key for state in dataset_stats batch_size: Batch size for processing slices Returns: @@ -258,6 +264,15 @@ def run_inference( text_embedding = torch.tensor(text_embedding, dtype=torch.float32) if states is not None: state_embeddings = torch.tensor(states, dtype=torch.float32) + + # Normalize states using dataset stats (same as training processor) + if dataset_stats is not None and state_key in dataset_stats: + mean = torch.tensor(dataset_stats[state_key]["mean"], dtype=torch.float32) + std = torch.tensor(dataset_stats[state_key]["std"], dtype=torch.float32) + state_embeddings = (state_embeddings - mean) / (std + 1e-8) + logger.info(f"✓ Applied MEAN_STD normalization to states using {state_key}") + else: + logger.warning("⚠ No dataset_stats provided - states not normalized (may differ from training)") else: state_embeddings = None @@ -280,6 +295,8 @@ def run_inference( video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512) if state_embeddings is not None: state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim) + # Pad states to max_state_dim (same as training processor) + state_slices = pad_state_to_max_dim(state_slices, model.config.max_state_dim) else: state_slices = None @@ -346,14 +363,13 @@ def visualize_predictions( else: stage_labels = [f'Stage {i+1}' for i in range(num_stages)] - if show_frames: - # Create figure with progress plot, stage plot, and sample frames - fig = plt.figure(figsize=(figsize[0], figsize[1] + 4)) - gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3) - - ax_progress = fig.add_subplot(gs[0]) - ax_stages = fig.add_subplot(gs[1], sharex=ax_progress) - ax_frames = fig.add_subplot(gs[2]) + # Create figure with progress plot, stage plot, and sample frames + fig = plt.figure(figsize=(figsize[0], figsize[1] + 4)) + gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3) + + ax_progress = fig.add_subplot(gs[0]) + ax_stages = fig.add_subplot(gs[1], sharex=ax_progress) + ax_frames = fig.add_subplot(gs[2]) frame_indices = np.arange(len(progress_predictions)) @@ -490,11 +506,20 @@ def main(): ) image_key = args.image_key if args.image_key is not None else model.config.image_key + state_key = args.state_key if args.state_key is not None else model.config.state_key logger.info(f"Using image key: {image_key}") + logger.info(f"Using state key: {state_key}") + + # Load dataset stats for state normalization (same as training) + dataset_stats = load_stats(dataset.root) + if dataset_stats: + logger.info(f"✓ Loaded dataset stats from {dataset.root}") + else: + logger.warning("⚠ Could not load dataset stats - states will not be normalized") # Load episode data frames, states, start_idx, end_idx, dataset_task = load_episode_data( - dataset, args.episode_index, image_key, args.state_key + dataset, args.episode_index, image_key, state_key ) # Use task description from dataset if available, otherwise use command-line argument @@ -502,7 +527,10 @@ def main(): logger.info(f"Using task description: '{task_description}'") # Run inference - progress_predictions, stage_predictions = run_inference(model, frames, states, task_description) + progress_predictions, stage_predictions = run_inference( + model, frames, states, task_description, + dataset_stats=dataset_stats, state_key=state_key + ) # Extract subtask names and temporal proportions from model config if available subtask_names = None