mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
subtasks
This commit is contained in:
@@ -338,7 +338,9 @@ def visualize_predictions(
|
||||
output_path: Path,
|
||||
show_frames: bool = False,
|
||||
num_sample_frames: int = 8,
|
||||
figsize: tuple = (14, 8)
|
||||
figsize: tuple = (14, 8),
|
||||
subtask_names: list[str] | None = None,
|
||||
temporal_proportions: dict[str, float] | None = None
|
||||
):
|
||||
"""
|
||||
Create visualization of SARM predictions.
|
||||
@@ -352,10 +354,18 @@ def visualize_predictions(
|
||||
show_frames: Whether to include sample frames
|
||||
num_sample_frames: Number of frames to show
|
||||
figsize: Figure size (width, height)
|
||||
subtask_names: Optional list of subtask names for labeling
|
||||
temporal_proportions: Optional dict of temporal proportions for each subtask
|
||||
"""
|
||||
num_stages = stage_predictions.shape[1]
|
||||
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
|
||||
|
||||
# Use subtask names if available, otherwise use generic labels
|
||||
if subtask_names is not None and len(subtask_names) == num_stages:
|
||||
stage_labels = subtask_names
|
||||
else:
|
||||
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
|
||||
|
||||
if show_frames:
|
||||
# Create figure with progress plot, stage plot, and sample frames
|
||||
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
|
||||
@@ -398,12 +408,35 @@ def visualize_predictions(
|
||||
|
||||
# Plot 2: Stage predictions (stacked area plot)
|
||||
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
|
||||
colors=stage_colors, alpha=0.8, labels=[f'Stage {i+1}' for i in range(num_stages)])
|
||||
colors=stage_colors, alpha=0.8, labels=stage_labels)
|
||||
ax_stages.set_xlabel('Frame Index', fontsize=12)
|
||||
ax_stages.set_ylabel('Stage Probability', fontsize=12)
|
||||
ax_stages.set_ylim(0, 1)
|
||||
ax_stages.grid(True, alpha=0.3)
|
||||
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
|
||||
|
||||
# Adjust legend based on number of stages and label lengths
|
||||
if num_stages <= 5:
|
||||
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
|
||||
else:
|
||||
ax_stages.legend(loc='upper left', ncol=3, fontsize=7)
|
||||
|
||||
# Add vertical lines and labels for expected stage transitions (if temporal proportions available)
|
||||
if temporal_proportions is not None and subtask_names is not None:
|
||||
cumulative_progress = 0.0
|
||||
for i, name in enumerate(stage_labels):
|
||||
if name in temporal_proportions:
|
||||
# Find approximate frame where this stage should end
|
||||
stage_end_progress = cumulative_progress + temporal_proportions[name]
|
||||
|
||||
# Find frame index closest to this progress
|
||||
progress_diffs = np.abs(progress_predictions - stage_end_progress)
|
||||
stage_end_frame = np.argmin(progress_diffs)
|
||||
|
||||
# Draw vertical line
|
||||
ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
|
||||
ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
|
||||
|
||||
cumulative_progress = stage_end_progress
|
||||
|
||||
# Plot 3: Sample frames (if requested)
|
||||
if show_frames:
|
||||
@@ -431,7 +464,13 @@ def visualize_predictions(
|
||||
# Add frame number, progress, and stage
|
||||
progress_val = progress_predictions[frame_idx]
|
||||
stage_idx = np.argmax(stage_predictions[frame_idx])
|
||||
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\nStage: {stage_idx+1}'
|
||||
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
|
||||
|
||||
# Truncate long stage names for display
|
||||
if len(stage_name) > 15:
|
||||
stage_name = stage_name[:12] + '...'
|
||||
|
||||
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
|
||||
|
||||
# Draw label on image
|
||||
ax_frames.text(x_start + frame_width / 2, -10, label,
|
||||
@@ -495,6 +534,29 @@ def main():
|
||||
# Run inference
|
||||
progress_predictions, stage_predictions = run_inference(model, frames, states, task_description)
|
||||
|
||||
# Extract subtask names and temporal proportions from model config if available
|
||||
subtask_names = None
|
||||
temporal_proportions = None
|
||||
|
||||
if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None:
|
||||
subtask_names = model.config.subtask_names
|
||||
logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
|
||||
|
||||
# Try to load temporal proportions from model's dataset meta
|
||||
if hasattr(model, 'dataset_stats') and model.dataset_stats is not None:
|
||||
if 'temporal_proportions' in model.dataset_stats:
|
||||
temporal_proportions = model.dataset_stats['temporal_proportions']
|
||||
logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
|
||||
|
||||
# Fallback: try to load from dataset meta
|
||||
if temporal_proportions is None and subtask_names is not None:
|
||||
import json
|
||||
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
|
||||
if proportions_path.exists():
|
||||
with open(proportions_path, 'r') as f:
|
||||
temporal_proportions = json.load(f)
|
||||
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
|
||||
|
||||
# Create visualization
|
||||
output_dir = Path(args.output_dir)
|
||||
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
|
||||
@@ -507,7 +569,9 @@ def main():
|
||||
output_path,
|
||||
show_frames=args.show_frames,
|
||||
num_sample_frames=args.num_sample_frames,
|
||||
figsize=tuple(args.figsize)
|
||||
figsize=tuple(args.figsize),
|
||||
subtask_names=subtask_names,
|
||||
temporal_proportions=temporal_proportions
|
||||
)
|
||||
|
||||
# Save predictions as numpy arrays
|
||||
@@ -527,8 +591,30 @@ def main():
|
||||
logger.info(f"Final Progress: {progress_predictions[-1]:.3f}")
|
||||
logger.info(f"Max Progress: {progress_predictions.max():.3f}")
|
||||
logger.info(f"Mean Progress: {progress_predictions.mean():.3f}")
|
||||
logger.info(f"Most Common Stage: {np.argmax(np.sum(stage_predictions, axis=0)) + 1}")
|
||||
logger.info(f"Visualization: {output_path}")
|
||||
|
||||
# Show most common stage with name if available
|
||||
most_common_stage_idx = np.argmax(np.sum(stage_predictions, axis=0))
|
||||
if subtask_names is not None and most_common_stage_idx < len(subtask_names):
|
||||
most_common_stage_name = subtask_names[most_common_stage_idx]
|
||||
logger.info(f"Most Common Stage: {most_common_stage_name} (Stage {most_common_stage_idx + 1})")
|
||||
else:
|
||||
logger.info(f"Most Common Stage: {most_common_stage_idx + 1}")
|
||||
|
||||
# Show subtask breakdown
|
||||
if subtask_names is not None:
|
||||
logger.info("\nSubtask Breakdown:")
|
||||
total_frames = len(stage_predictions)
|
||||
for i, name in enumerate(subtask_names):
|
||||
# Calculate percentage of frames where this stage was dominant
|
||||
dominant_frames = np.sum(np.argmax(stage_predictions, axis=1) == i)
|
||||
percentage = (dominant_frames / total_frames) * 100
|
||||
logger.info(f" {name}: {dominant_frames}/{total_frames} frames ({percentage:.1f}%)")
|
||||
|
||||
if temporal_proportions is not None and name in temporal_proportions:
|
||||
expected_pct = temporal_proportions[name] * 100
|
||||
logger.info(f" Expected: {expected_pct:.1f}% | Actual: {percentage:.1f}%")
|
||||
|
||||
logger.info(f"\nVisualization: {output_path}")
|
||||
logger.info("="*60)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user