use task from dataset, clean up visualizer

Pepijn
2025-11-27 14:14:52 +01:00
parent f2ad86831d
commit 2889c0650a
3 changed files with 61 additions and 111 deletions
+50 -103
@@ -265,13 +265,11 @@ def run_inference(
state_slices = []
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices using SARM pattern:
# [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
# The first delta is -100000 which clamps to 0 (episode start)
# Compute frame indices: [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
# The first delta is -100000, which clamps to the episode start
deltas = model.config.observation_delta_indices
frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas]
# Extract slice
video_slice = video_embeddings[frame_indices]
video_slices.append(video_slice)
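For reference, a minimal sketch of how the delta pattern above resolves to clamped frame indices; the gap and episode length below are hypothetical illustration values, the real deltas come from model.config.observation_delta_indices:

# Sketch: resolving SARM observation deltas to clamped frame indices.
# gap=30 and episode_len=100 are hypothetical illustration values.
gap = 30
deltas = [-100000] + [-(7 - i) * gap for i in range(8)]  # [-100000, -210, ..., -30, 0]
episode_len = 100
current_frame = 50
frame_indices = [max(0, min(current_frame + d, episode_len - 1)) for d in deltas]
# -> [0, 0, 0, 0, 0, 0, 0, 20, 50]; the sentinel always lands on frame 0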
@@ -319,7 +317,6 @@ def visualize_predictions(
stage_predictions: np.ndarray,
task_description: str,
output_path: Path,
show_frames: bool = False,
num_sample_frames: int = 8,
figsize: tuple = (14, 8),
subtask_names: list[str] | None = None,
@@ -357,13 +354,6 @@ def visualize_predictions(
ax_progress = fig.add_subplot(gs[0])
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
ax_frames = fig.add_subplot(gs[2])
else:
# Just progress and stage plots
fig = plt.figure(figsize=figsize)
gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)
ax_progress = fig.add_subplot(gs[0])
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
frame_indices = np.arange(len(progress_predictions))
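With the show_frames branch removed, the figure always uses the three-row layout. A minimal sketch of that layout; the height ratios and hspace here are assumptions, not values taken from the diff:

# Sketch: three-row layout with a shared x-axis for progress and stages.
# height_ratios/hspace are illustrative assumptions.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(14, 8))
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.4)
ax_progress = fig.add_subplot(gs[0])
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)  # aligned frame axis
ax_frames = fig.add_subplot(gs[2])                      # sample-frame strip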
@@ -422,48 +412,46 @@ def visualize_predictions(
cumulative_progress = stage_end_progress
# Plot 3: Sample frames (if requested)
if show_frames:
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
ax_frames.axis('off')
# Create grid for frames
frame_height = frames[0].shape[0]
frame_width = frames[0].shape[1]
combined_width = frame_width * num_sample_frames
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
for i, frame_idx in enumerate(frame_indices_to_show):
frame = frames[frame_idx]
if frame.shape[-1] == 1:
frame = np.repeat(frame, 3, axis=-1)
# Add frame to combined image
x_start = i * frame_width
x_end = (i + 1) * frame_width
combined_image[:, x_start:x_end] = frame
# Add frame number, progress, and stage
progress_val = progress_predictions[frame_idx]
stage_idx = np.argmax(stage_predictions[frame_idx])
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
# Truncate long stage names for display
if len(stage_name) > 15:
stage_name = stage_name[:12] + '...'
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
# Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label,
ha='center', va='top', fontsize=7,
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
ax_frames.imshow(combined_image)
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
ax_frames.axis('off')
# Create grid for frames
frame_height = frames[0].shape[0]
frame_width = frames[0].shape[1]
combined_width = frame_width * num_sample_frames
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
for i, frame_idx in enumerate(frame_indices_to_show):
frame = frames[frame_idx]
if frame.shape[-1] == 1:
frame = np.repeat(frame, 3, axis=-1)
# Add frame to combined image
x_start = i * frame_width
x_end = (i + 1) * frame_width
combined_image[:, x_start:x_end] = frame
# Add frame number, progress, and stage
progress_val = progress_predictions[frame_idx]
stage_idx = np.argmax(stage_predictions[frame_idx])
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
# Truncate long stage names for display
if len(stage_name) > 15:
stage_name = stage_name[:12] + '...'
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
# Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label,
ha='center', va='top', fontsize=7,
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
ax_frames.imshow(combined_image)
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
# Save figure
plt.tight_layout()
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
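The strip assembly above assumes every frame shares the same height, width, and uint8 RGB layout. A tiny self-contained check with dummy frames (sizes are hypothetical):

# Sketch: the horizontal strip built above, exercised on dummy uint8 frames.
import numpy as np

num_sample_frames, h, w = 4, 64, 96            # hypothetical sizes
frames = [np.full((h, w, 3), i * 60, dtype=np.uint8) for i in range(num_sample_frames)]
combined = np.zeros((h, w * num_sample_frames, 3), dtype=np.uint8)
for i, frame in enumerate(frames):
    combined[:, i * w:(i + 1) * w] = frame     # same slot math as the diff
assert combined.shape == (64, 384, 3)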
@@ -501,7 +489,6 @@ def main():
f"Dataset has {len(dataset.meta.episodes)} episodes."
)
# Determine which image key to use
image_key = args.image_key if args.image_key is not None else model.config.image_key
logger.info(f"Using image key: {image_key}")
@@ -531,16 +518,15 @@ def main():
temporal_proportions = model.dataset_stats['temporal_proportions']
logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
# Fallback: try to load from dataset meta
if temporal_proportions is None and subtask_names is not None:
import json
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
if proportions_path.exists():
with open(proportions_path, 'r') as f:
temporal_proportions = json.load(f)
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
# # Fallback: try to load from dataset meta
# if temporal_proportions is None and subtask_names is not None:
# import json
# proportions_path = dataset.root / "meta" / "temporal_proportions.json"
# if proportions_path.exists():
# with open(proportions_path, 'r') as f:
# temporal_proportions = json.load(f)
# logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
# Create visualization
output_dir = Path(args.output_dir)
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
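For context, temporal_proportions maps each subtask name to its expected share of the episode, so the values should sum to 1.0; the names and numbers below are hypothetical:

# Sketch: shape of temporal_proportions (hypothetical subtasks and values).
temporal_proportions = {
    "reach": 0.25,
    "grasp": 0.25,
    "lift": 0.20,
    "place": 0.30,
}
assert abs(sum(temporal_proportions.values()) - 1.0) < 1e-9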
@@ -550,55 +536,16 @@ def main():
stage_predictions,
task_description,
output_path,
show_frames=args.show_frames,
num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize),
subtask_names=subtask_names,
temporal_proportions=temporal_proportions
)
# Save predictions as numpy arrays
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
np.savez(predictions_path, progress=progress_predictions, stages=stage_predictions)
logger.info(f"Saved predictions to {predictions_path}")
# Print summary
logger.info("\n" + "="*60)
logger.info("INFERENCE SUMMARY")
logger.info("="*60)
logger.info(f"Model: {args.model_id}")
logger.info(f"Dataset: {args.dataset_repo}")
logger.info(f"Episode: {args.episode_index}")
logger.info(f"Task: {task_description}")
logger.info(f"Frames: {len(frames)}")
logger.info(f"Final Progress: {progress_predictions[-1]:.3f}")
logger.info(f"Max Progress: {progress_predictions.max():.3f}")
logger.info(f"Mean Progress: {progress_predictions.mean():.3f}")
# Show most common stage with name if available
most_common_stage_idx = np.argmax(np.sum(stage_predictions, axis=0))
if subtask_names is not None and most_common_stage_idx < len(subtask_names):
most_common_stage_name = subtask_names[most_common_stage_idx]
logger.info(f"Most Common Stage: {most_common_stage_name} (Stage {most_common_stage_idx + 1})")
else:
logger.info(f"Most Common Stage: {most_common_stage_idx + 1}")
# Show subtask breakdown
if subtask_names is not None:
logger.info("\nSubtask Breakdown:")
total_frames = len(stage_predictions)
for i, name in enumerate(subtask_names):
# Calculate percentage of frames where this stage was dominant
dominant_frames = np.sum(np.argmax(stage_predictions, axis=1) == i)
percentage = (dominant_frames / total_frames) * 100
logger.info(f" {name}: {dominant_frames}/{total_frames} frames ({percentage:.1f}%)")
if temporal_proportions is not None and name in temporal_proportions:
expected_pct = temporal_proportions[name] * 100
logger.info(f" Expected: {expected_pct:.1f}% | Actual: {percentage:.1f}%")
logger.info(f"\nVisualization: {output_path}")
logger.info("="*60)
if __name__ == "__main__":
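The .npz written with np.savez above can be reloaded later for offline analysis; a minimal sketch, with a hypothetical path:

# Sketch: reloading the predictions saved with np.savez above.
import numpy as np

data = np.load("outputs/predictions_ep0.npz")   # hypothetical path
progress = data["progress"]                     # (num_frames,)
stages = data["stages"]                         # (num_frames, num_stages)
dominant = stages.argmax(axis=1)                # per-frame dominant stage index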
+0 -1
@@ -55,7 +55,6 @@ class SARMConfig(PreTrainedConfig):
# Processor settings
image_key: str = "observation.images.top" # Key for image used from the dataset
task_description: str = "perform the task" # Default task description
# State key in the dataset (for normalization)
state_key: str = "observation.state"
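A minimal sketch of overriding these fields at construction time; the keyword-style init and the camera key are assumptions about SARMConfig's interface, not a documented API:

# Sketch: overriding processor-related config fields (assumed keyword init).
config = SARMConfig(
    image_key="observation.images.wrist",  # hypothetical camera key
    state_key="observation.state",
)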
+11 -7
@@ -49,14 +49,12 @@ class SARMEncodingProcessorStep(ProcessorStep):
self,
config: SARMConfig,
image_key: str | None = None,
task_description: str | None = None,
dataset_meta = None,
dataset_stats: dict | None = None,
):
super().__init__()
self.config = config
self.image_key = image_key or config.image_key
self.task_description = task_description or config.task_description
self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats
self.temporal_proportions = None
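After this change the step no longer takes a task_description; a hedged construction sketch, with placeholder argument values:

# Sketch: constructing the step post-change; the task now arrives per-sample
# via complementary data rather than at construction time.
step = SARMEncodingProcessorStep(
    config=config,
    image_key=None,          # falls back to config.image_key
    dataset_meta=None,
    dataset_stats=None,
)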
@@ -442,15 +440,21 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Pad state
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
# Encode text with CLIP
batch_size = video_features.shape[0]
observation['text_features'] = self._encode_text_clip(self.task_description, batch_size)
# Extract frame/episode indices from complementary data
# Extract complementary data (includes task from dataset)
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if not isinstance(comp_data, dict):
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
# Get task description from dataset (complementary_data["task"])
task = comp_data.get('task')
if isinstance(task, list):
# If batch, take first task (assuming same task for all items in batch)
task = task[0] if task else ""
# Encode text with CLIP
batch_size = video_features.shape[0]
observation['text_features'] = self._encode_text_clip(task, batch_size)
frame_index = comp_data.get('index')
episode_index = comp_data.get('episode_index')
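Finally, a small worked example of the task extraction above, covering the batched-list, plain-string, and empty-list cases (values are hypothetical):

# Sketch: task extraction from complementary data, as implemented above.
for comp_data in ({"task": ["fold the towel", "fold the towel"]},
                  {"task": "fold the towel"},
                  {"task": []}):
    task = comp_data.get('task')
    if isinstance(task, list):
        task = task[0] if task else ""
    print(repr(task))  # 'fold the towel', 'fold the towel', ''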