From 2889c0650abe4a2a6803747168a5b31d0d66dd8e Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Thu, 27 Nov 2025 14:14:52 +0100
Subject: [PATCH] use task from dataset, cleanup visualizer

---
 scripts/visualize_sarm_predictions.py         | 153 ++++------
 .../policies/sarm/configuration_sarm.py       |   1 -
 src/lerobot/policies/sarm/processor_sarm.py   |  18 ++-
 3 files changed, 61 insertions(+), 111 deletions(-)
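
Note: for each frame t, run_inference builds the clamped history-index list
[initial_frame (0), t-(7*gap), ..., t-gap, t] described in the comment below.
A minimal standalone sketch of that pattern; gap and num_frames are
illustrative assumptions, not values taken from this patch:

    # Clamped sliding-window indices, mirroring the slice construction below.
    gap = 30          # assumed spacing between history frames
    num_frames = 200  # assumed episode length
    deltas = [-100000] + [-(k * gap) for k in range(7, 0, -1)] + [0]

    def slice_indices(current_frame: int) -> list[int]:
        # The -100000 delta always clamps to 0 (episode start); the rest
        # clamp into [0, num_frames - 1].
        return [max(0, min(current_frame + d, num_frames - 1)) for d in deltas]

    print(slice_indices(100))  # -> [0, 0, 0, 0, 0, 10, 40, 70, 100]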
diff --git a/scripts/visualize_sarm_predictions.py b/scripts/visualize_sarm_predictions.py
index eeb81d07e..725f76a53 100644
--- a/scripts/visualize_sarm_predictions.py
+++ b/scripts/visualize_sarm_predictions.py
@@ -265,13 +265,11 @@ def run_inference(
     state_slices = []
 
     for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
-        # Compute frame indices using SARM pattern:
-        # [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
-        # The first delta is -100000 which clamps to 0 (episode start)
+        # Compute frame indices: [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
+        # The first delta is -100000 which clamps to episode start
         deltas = model.config.observation_delta_indices
         frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas]
-
-        # Extract slice
+
         video_slice = video_embeddings[frame_indices]
         video_slices.append(video_slice)
 
@@ -319,7 +317,6 @@ def visualize_predictions(
     stage_predictions: np.ndarray,
     task_description: str,
     output_path: Path,
-    show_frames: bool = False,
     num_sample_frames: int = 8,
     figsize: tuple = (14, 8),
     subtask_names: list[str] | None = None,
@@ -357,13 +354,6 @@
         ax_progress = fig.add_subplot(gs[0])
         ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
         ax_frames = fig.add_subplot(gs[2])
-    else:
-        # Just progress and stage plots
-        fig = plt.figure(figsize=figsize)
-        gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)
-
-        ax_progress = fig.add_subplot(gs[0])
-        ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
 
     frame_indices = np.arange(len(progress_predictions))
 
@@ -422,48 +412,46 @@
         cumulative_progress = stage_end_progress
 
     # Plot 3: Sample frames (if requested)
-    if show_frames:
-        frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
-
-        ax_frames.axis('off')
-
-        # Create grid for frames
-        frame_height = frames[0].shape[0]
-        frame_width = frames[0].shape[1]
-
-        combined_width = frame_width * num_sample_frames
-        combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
-
-        for i, frame_idx in enumerate(frame_indices_to_show):
-            frame = frames[frame_idx]
-            if frame.shape[-1] == 1:
-                frame = np.repeat(frame, 3, axis=-1)
-
-            # Add frame to combined image
-            x_start = i * frame_width
-            x_end = (i + 1) * frame_width
-            combined_image[:, x_start:x_end] = frame
-
-            # Add frame number, progress, and stage
-            progress_val = progress_predictions[frame_idx]
-            stage_idx = np.argmax(stage_predictions[frame_idx])
-            stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
-
-            # Truncate long stage names for display
-            if len(stage_name) > 15:
-                stage_name = stage_name[:12] + '...'
-
-            label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
-
-            # Draw label on image
-            ax_frames.text(x_start + frame_width / 2, -10, label,
-                           ha='center', va='top', fontsize=7,
-                           bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
-
-        ax_frames.imshow(combined_image)
-        ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
+    frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
+
+    ax_frames.axis('off')
+
+    # Create grid for frames
+    frame_height = frames[0].shape[0]
+    frame_width = frames[0].shape[1]
+
+    combined_width = frame_width * num_sample_frames
+    combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
+
+    for i, frame_idx in enumerate(frame_indices_to_show):
+        frame = frames[frame_idx]
+        if frame.shape[-1] == 1:
+            frame = np.repeat(frame, 3, axis=-1)
+
+        # Add frame to combined image
+        x_start = i * frame_width
+        x_end = (i + 1) * frame_width
+        combined_image[:, x_start:x_end] = frame
+
+        # Add frame number, progress, and stage
+        progress_val = progress_predictions[frame_idx]
+        stage_idx = np.argmax(stage_predictions[frame_idx])
+        stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
+
+        # Truncate long stage names for display
+        if len(stage_name) > 15:
+            stage_name = stage_name[:12] + '...'
+
+        label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
+
+        # Draw label on image
+        ax_frames.text(x_start + frame_width / 2, -10, label,
+                       ha='center', va='top', fontsize=7,
+                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
+
+    ax_frames.imshow(combined_image)
+    ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
 
-    # Save figure
     plt.tight_layout()
     output_path.parent.mkdir(parents=True, exist_ok=True)
     plt.savefig(output_path, dpi=150, bbox_inches='tight')
@@ -501,7 +489,6 @@ def main():
             f"Dataset has {len(dataset.meta.episodes)} episodes."
         )
 
-    # Determine which image key to use
     image_key = args.image_key if args.image_key is not None else model.config.image_key
     logger.info(f"Using image key: {image_key}")
 
@@ -531,16 +518,15 @@
         temporal_proportions = model.dataset_stats['temporal_proportions']
         logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
 
-    # Fallback: try to load from dataset meta
-    if temporal_proportions is None and subtask_names is not None:
-        import json
-        proportions_path = dataset.root / "meta" / "temporal_proportions.json"
-        if proportions_path.exists():
-            with open(proportions_path, 'r') as f:
-                temporal_proportions = json.load(f)
-            logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
+    # # Fallback: try to load from dataset meta
+    # if temporal_proportions is None and subtask_names is not None:
+    #     import json
+    #     proportions_path = dataset.root / "meta" / "temporal_proportions.json"
+    #     if proportions_path.exists():
+    #         with open(proportions_path, 'r') as f:
+    #             temporal_proportions = json.load(f)
+    #         logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
 
-    # Create visualization
     output_dir = Path(args.output_dir)
     output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
 
@@ -550,55 +536,16 @@
         stage_predictions,
         task_description,
         output_path,
-        show_frames=args.show_frames,
         num_sample_frames=args.num_sample_frames,
         figsize=tuple(args.figsize),
         subtask_names=subtask_names,
         temporal_proportions=temporal_proportions
     )
 
-    # Save predictions as numpy arrays
     predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
     np.savez(predictions_path, progress=progress_predictions, stages=stage_predictions)
     logger.info(f"Saved predictions to {predictions_path}")
-
-    # Print summary
-    logger.info("\n" + "="*60)
-    logger.info("INFERENCE SUMMARY")
-    logger.info("="*60)
-    logger.info(f"Model: {args.model_id}")
-    logger.info(f"Dataset: {args.dataset_repo}")
-    logger.info(f"Episode: {args.episode_index}")
-    logger.info(f"Task: {task_description}")
-    logger.info(f"Frames: {len(frames)}")
-    logger.info(f"Final Progress: {progress_predictions[-1]:.3f}")
-    logger.info(f"Max Progress: {progress_predictions.max():.3f}")
-    logger.info(f"Mean Progress: {progress_predictions.mean():.3f}")
-
-    # Show most common stage with name if available
-    most_common_stage_idx = np.argmax(np.sum(stage_predictions, axis=0))
-    if subtask_names is not None and most_common_stage_idx < len(subtask_names):
-        most_common_stage_name = subtask_names[most_common_stage_idx]
-        logger.info(f"Most Common Stage: {most_common_stage_name} (Stage {most_common_stage_idx + 1})")
-    else:
-        logger.info(f"Most Common Stage: {most_common_stage_idx + 1}")
-
-    # Show subtask breakdown
-    if subtask_names is not None:
-        logger.info("\nSubtask Breakdown:")
-        total_frames = len(stage_predictions)
-        for i, name in enumerate(subtask_names):
-            # Calculate percentage of frames where this stage was dominant
-            dominant_frames = np.sum(np.argmax(stage_predictions, axis=1) == i)
-            percentage = (dominant_frames / total_frames) * 100
-            logger.info(f"  {name}: {dominant_frames}/{total_frames} frames ({percentage:.1f}%)")
-
-            if temporal_proportions is not None and name in temporal_proportions:
-                expected_pct = temporal_proportions[name] * 100
-                logger.info(f"    Expected: {expected_pct:.1f}% | Actual: {percentage:.1f}%")
-
     logger.info(f"\nVisualization: {output_path}")
-    logger.info("="*60)
 
 
 if __name__ == "__main__":
diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py
index 16ba5d6da..0fa2e0b85 100644
--- a/src/lerobot/policies/sarm/configuration_sarm.py
+++ b/src/lerobot/policies/sarm/configuration_sarm.py
@@ -55,7 +55,6 @@ class SARMConfig(PreTrainedConfig):
 
     # Processor settings
     image_key: str = "observation.images.top"  # Key for image used from the dataset
-    task_description: str = "perform the task"  # Default task description
 
     # State key in the dataset (for normalization)
     state_key: str = "observation.state"
diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py
index b820c4966..45e2f6821 100644
--- a/src/lerobot/policies/sarm/processor_sarm.py
+++ b/src/lerobot/policies/sarm/processor_sarm.py
@@ -49,14 +49,12 @@ class SARMEncodingProcessorStep(ProcessorStep):
         self,
         config: SARMConfig,
         image_key: str | None = None,
-        task_description: str | None = None,
         dataset_meta = None,
         dataset_stats: dict | None = None,
     ):
         super().__init__()
         self.config = config
         self.image_key = image_key or config.image_key
-        self.task_description = task_description or config.task_description
         self.dataset_meta = dataset_meta
         self.dataset_stats = dataset_stats
         self.temporal_proportions = None
@@ -442,15 +440,21 @@
         # Pad state
         observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
 
-        # Encode text with CLIP
-        batch_size = video_features.shape[0]
-        observation['text_features'] = self._encode_text_clip(self.task_description, batch_size)
-
-        # Extract frame/episode indices from complementary data
+        # Extract complementary data (includes task from dataset)
         comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
         if not isinstance(comp_data, dict):
            raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
 
+        # Get task description from dataset (complementary_data["task"]); default to "" if missing
+        task = comp_data.get('task') or ""
+        if isinstance(task, list):
+            # If batched, take the first task (assumes the same task across the batch)
+            task = task[0] if task else ""
+
+        # Encode text with CLIP
+        batch_size = video_features.shape[0]
+        observation['text_features'] = self._encode_text_clip(task, batch_size)
+
         frame_index = comp_data.get('index')
         episode_index = comp_data.get('episode_index')
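
Note: with this change the processor reads its task string from the batch's
complementary data rather than from SARMConfig.task_description. A minimal
standalone sketch of the new lookup; the plain dict below stands in for
new_transition[TransitionKey.COMPLEMENTARY_DATA]:

    # Mirrors the task lookup added to SARMEncodingProcessorStep above.
    def resolve_task(comp_data: dict) -> str:
        task = comp_data.get('task') or ""
        if isinstance(task, list):
            # Batched input: assume one task per batch and take the first entry
            task = task[0] if task else ""
        return task

    assert resolve_task({'task': 'fold the towel'}) == 'fold the towel'
    assert resolve_task({'task': ['fold the towel'] * 8}) == 'fold the towel'
    assert resolve_task({}) == ""  # missing task falls back to empty string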