mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
use task from dataset, cleanup visualizer
This commit is contained in:
@@ -265,13 +265,11 @@ def run_inference(
|
||||
state_slices = []
|
||||
|
||||
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
|
||||
# Compute frame indices using SARM pattern:
|
||||
# [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
# The first delta is -100000 which clamps to 0 (episode start)
|
||||
# Compute frame indices: [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
# The first delta is -100000 which clamps to episode start
|
||||
deltas = model.config.observation_delta_indices
|
||||
frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas]
|
||||
|
||||
# Extract slice
|
||||
|
||||
video_slice = video_embeddings[frame_indices]
|
||||
video_slices.append(video_slice)
|
||||
|
||||
@@ -319,7 +317,6 @@ def visualize_predictions(
|
||||
stage_predictions: np.ndarray,
|
||||
task_description: str,
|
||||
output_path: Path,
|
||||
show_frames: bool = False,
|
||||
num_sample_frames: int = 8,
|
||||
figsize: tuple = (14, 8),
|
||||
subtask_names: list[str] | None = None,
|
||||
@@ -357,13 +354,6 @@ def visualize_predictions(
|
||||
ax_progress = fig.add_subplot(gs[0])
|
||||
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
|
||||
ax_frames = fig.add_subplot(gs[2])
|
||||
else:
|
||||
# Just progress and stage plots
|
||||
fig = plt.figure(figsize=figsize)
|
||||
gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)
|
||||
|
||||
ax_progress = fig.add_subplot(gs[0])
|
||||
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
|
||||
|
||||
frame_indices = np.arange(len(progress_predictions))
|
||||
|
||||
@@ -422,48 +412,46 @@ def visualize_predictions(
|
||||
cumulative_progress = stage_end_progress
|
||||
|
||||
# Plot 3: Sample frames (if requested)
|
||||
if show_frames:
|
||||
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
|
||||
|
||||
ax_frames.axis('off')
|
||||
|
||||
# Create grid for frames
|
||||
frame_height = frames[0].shape[0]
|
||||
frame_width = frames[0].shape[1]
|
||||
|
||||
combined_width = frame_width * num_sample_frames
|
||||
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
|
||||
|
||||
for i, frame_idx in enumerate(frame_indices_to_show):
|
||||
frame = frames[frame_idx]
|
||||
if frame.shape[-1] == 1:
|
||||
frame = np.repeat(frame, 3, axis=-1)
|
||||
|
||||
# Add frame to combined image
|
||||
x_start = i * frame_width
|
||||
x_end = (i + 1) * frame_width
|
||||
combined_image[:, x_start:x_end] = frame
|
||||
|
||||
# Add frame number, progress, and stage
|
||||
progress_val = progress_predictions[frame_idx]
|
||||
stage_idx = np.argmax(stage_predictions[frame_idx])
|
||||
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
|
||||
|
||||
# Truncate long stage names for display
|
||||
if len(stage_name) > 15:
|
||||
stage_name = stage_name[:12] + '...'
|
||||
|
||||
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
|
||||
|
||||
# Draw label on image
|
||||
ax_frames.text(x_start + frame_width / 2, -10, label,
|
||||
ha='center', va='top', fontsize=7,
|
||||
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
|
||||
|
||||
ax_frames.imshow(combined_image)
|
||||
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
|
||||
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
|
||||
|
||||
ax_frames.axis('off')
|
||||
|
||||
# Create grid for frames
|
||||
frame_height = frames[0].shape[0]
|
||||
frame_width = frames[0].shape[1]
|
||||
|
||||
combined_width = frame_width * num_sample_frames
|
||||
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
|
||||
|
||||
for i, frame_idx in enumerate(frame_indices_to_show):
|
||||
frame = frames[frame_idx]
|
||||
if frame.shape[-1] == 1:
|
||||
frame = np.repeat(frame, 3, axis=-1)
|
||||
|
||||
# Add frame to combined image
|
||||
x_start = i * frame_width
|
||||
x_end = (i + 1) * frame_width
|
||||
combined_image[:, x_start:x_end] = frame
|
||||
|
||||
# Add frame number, progress, and stage
|
||||
progress_val = progress_predictions[frame_idx]
|
||||
stage_idx = np.argmax(stage_predictions[frame_idx])
|
||||
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
|
||||
|
||||
# Truncate long stage names for display
|
||||
if len(stage_name) > 15:
|
||||
stage_name = stage_name[:12] + '...'
|
||||
|
||||
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
|
||||
|
||||
# Draw label on image
|
||||
ax_frames.text(x_start + frame_width / 2, -10, label,
|
||||
ha='center', va='top', fontsize=7,
|
||||
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
|
||||
|
||||
ax_frames.imshow(combined_image)
|
||||
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
|
||||
|
||||
# Save figure
|
||||
plt.tight_layout()
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
||||
@@ -501,7 +489,6 @@ def main():
|
||||
f"Dataset has {len(dataset.meta.episodes)} episodes."
|
||||
)
|
||||
|
||||
# Determine which image key to use
|
||||
image_key = args.image_key if args.image_key is not None else model.config.image_key
|
||||
logger.info(f"Using image key: {image_key}")
|
||||
|
||||
@@ -531,16 +518,15 @@ def main():
|
||||
temporal_proportions = model.dataset_stats['temporal_proportions']
|
||||
logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
|
||||
|
||||
# Fallback: try to load from dataset meta
|
||||
if temporal_proportions is None and subtask_names is not None:
|
||||
import json
|
||||
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
|
||||
if proportions_path.exists():
|
||||
with open(proportions_path, 'r') as f:
|
||||
temporal_proportions = json.load(f)
|
||||
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
|
||||
# # Fallback: try to load from dataset meta
|
||||
# if temporal_proportions is None and subtask_names is not None:
|
||||
# import json
|
||||
# proportions_path = dataset.root / "meta" / "temporal_proportions.json"
|
||||
# if proportions_path.exists():
|
||||
# with open(proportions_path, 'r') as f:
|
||||
# temporal_proportions = json.load(f)
|
||||
# logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
|
||||
|
||||
# Create visualization
|
||||
output_dir = Path(args.output_dir)
|
||||
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
|
||||
|
||||
@@ -550,55 +536,16 @@ def main():
|
||||
stage_predictions,
|
||||
task_description,
|
||||
output_path,
|
||||
show_frames=args.show_frames,
|
||||
num_sample_frames=args.num_sample_frames,
|
||||
figsize=tuple(args.figsize),
|
||||
subtask_names=subtask_names,
|
||||
temporal_proportions=temporal_proportions
|
||||
)
|
||||
|
||||
# Save predictions as numpy arrays
|
||||
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
|
||||
np.savez(predictions_path, progress=progress_predictions, stages=stage_predictions)
|
||||
logger.info(f"Saved predictions to {predictions_path}")
|
||||
|
||||
# Print summary
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("INFERENCE SUMMARY")
|
||||
logger.info("="*60)
|
||||
logger.info(f"Model: {args.model_id}")
|
||||
logger.info(f"Dataset: {args.dataset_repo}")
|
||||
logger.info(f"Episode: {args.episode_index}")
|
||||
logger.info(f"Task: {task_description}")
|
||||
logger.info(f"Frames: {len(frames)}")
|
||||
logger.info(f"Final Progress: {progress_predictions[-1]:.3f}")
|
||||
logger.info(f"Max Progress: {progress_predictions.max():.3f}")
|
||||
logger.info(f"Mean Progress: {progress_predictions.mean():.3f}")
|
||||
|
||||
# Show most common stage with name if available
|
||||
most_common_stage_idx = np.argmax(np.sum(stage_predictions, axis=0))
|
||||
if subtask_names is not None and most_common_stage_idx < len(subtask_names):
|
||||
most_common_stage_name = subtask_names[most_common_stage_idx]
|
||||
logger.info(f"Most Common Stage: {most_common_stage_name} (Stage {most_common_stage_idx + 1})")
|
||||
else:
|
||||
logger.info(f"Most Common Stage: {most_common_stage_idx + 1}")
|
||||
|
||||
# Show subtask breakdown
|
||||
if subtask_names is not None:
|
||||
logger.info("\nSubtask Breakdown:")
|
||||
total_frames = len(stage_predictions)
|
||||
for i, name in enumerate(subtask_names):
|
||||
# Calculate percentage of frames where this stage was dominant
|
||||
dominant_frames = np.sum(np.argmax(stage_predictions, axis=1) == i)
|
||||
percentage = (dominant_frames / total_frames) * 100
|
||||
logger.info(f" {name}: {dominant_frames}/{total_frames} frames ({percentage:.1f}%)")
|
||||
|
||||
if temporal_proportions is not None and name in temporal_proportions:
|
||||
expected_pct = temporal_proportions[name] * 100
|
||||
logger.info(f" Expected: {expected_pct:.1f}% | Actual: {percentage:.1f}%")
|
||||
|
||||
logger.info(f"\nVisualization: {output_path}")
|
||||
logger.info("="*60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user