Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-16 17:20:05 +00:00)

Commit: add sarm

scripts/visualize_rewind_predictions.py
@@ -0,0 +1,528 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Inference script for ReWiND Reward Model.

This script loads a trained ReWiND model and runs inference on a dataset episode,
generating visualizations of the predicted task progression over time.

Example usage:
    python scripts/visualize_rewind_predictions.py \
        --model-id username/rewind-model \
        --dataset-repo lerobot/aloha_sim_insertion_human \
        --episode-index 0 \
        --output-dir outputs/rewind_viz \
        --task-description "insert the peg into the socket"
"""

import argparse
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import torch
from tqdm import tqdm

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Run ReWiND inference and visualize predictions")

    # Model arguments
    parser.add_argument(
        "--model-id",
        type=str,
        required=True,
        help="HuggingFace model ID or local path to trained ReWiND model",
    )

    # Dataset arguments
    parser.add_argument(
        "--dataset-repo",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)",
    )
    parser.add_argument(
        "--episode-index",
        type=int,
        default=0,
        help="Index of the episode to visualize (default: 0)",
    )
    parser.add_argument(
        "--task-description",
        type=str,
        default="perform the task",
        help="Task description for the reward model (default: 'perform the task')",
    )

    # Output arguments
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/rewind_inference",
        help="Directory to save visualization outputs (default: outputs/rewind_inference)",
    )
    parser.add_argument(
        "--image-key",
        type=str,
        default=None,
        help="Key for images in dataset (e.g., observation.images.image for jaco_play). If not specified, uses model config's image_key",
    )

    # Visualization options
    parser.add_argument(
        "--show-frames",
        action="store_true",
        help="Include sample frames in the visualization",
    )
    parser.add_argument(
        "--num-sample-frames",
        type=int,
        default=8,
        help="Number of sample frames to show (default: 8)",
    )
    parser.add_argument(
        "--figsize",
        type=int,
        nargs=2,
        default=[12, 6],
        help="Figure size as width height (default: 12 6)",
    )

    # Device
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run inference on (cuda/cpu, default: auto-detect)",
    )

    return parser.parse_args()


def load_episode_data(
    dataset: LeRobotDataset,
    episode_index: int,
    image_key: str,
) -> tuple[np.ndarray, int, int, str | None]:
    """
    Load all frames from a specific episode.

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode to load
        image_key: Key for accessing images in the dataset

    Returns:
        Tuple of (frames, start_index, end_index, task_description)
    """
    # Get episode boundaries
    episode_data = dataset.meta.episodes
    start_idx = episode_data["dataset_from_index"][episode_index]
    end_idx = episode_data["dataset_to_index"][episode_index]

    logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")

    # Get task description from the dataset if available
    task_description = None
    first_item = dataset[start_idx]
    if "task" in first_item:
        task_description = first_item["task"]
        logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")

    # Load all frames from the episode
    frames = []
    for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
        item = dataset[idx]
        # Get image from the item
        img = item[image_key]

        # Convert to numpy if needed
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()

        # Handle different image formats (C, H, W) or (H, W, C)
        if img.shape[0] in [1, 3]:  # Channel first
            img = np.transpose(img, (1, 2, 0))

        # Convert to uint8 if needed
        if img.dtype != np.uint8:
            if img.max() <= 1.0:
                img = (img * 255).astype(np.uint8)
            else:
                img = img.astype(np.uint8)

        frames.append(img)

    frames = np.array(frames)
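    # frames: (num_frames, H, W, C), dtype uint8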
    logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")

    return frames, start_idx, end_idx, task_description


@torch.no_grad()
def run_inference(
    model: ReWiNDRewardModel,
    frames: np.ndarray,
    task_description: str,
    batch_size: int = 32,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run ReWiND inference on video frames using the original ReWiND approach.

    This function creates video slices for all frames at once (similar to the original
    metaworld_label_reward.py), where each slice contains the frames from the start up to that point.

    Progress normalization (from the original ReWiND dataset.py):
    - Training: progress = [1, 2, ..., N] / remaining_length,
      where remaining_length = episode_end - sequence_start
    - Inference: starting from frame 0, remaining_length = total_episode_length,
      so the expected progress for frame i is (i + 1) / total_episode_length.
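      Example: a 4-frame episode gives expected progress
      [1/4, 2/4, 3/4, 4/4] = [0.25, 0.50, 0.75, 1.00].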

    This function computes both:
      1. Model predictions (what the model actually predicts)
      2. Expected progress (ground truth based on frame position)

    Args:
        model: ReWiND model
        frames: Video frames (num_frames, H, W, C)
        task_description: Task description text
        batch_size: Batch size for processing slices

    Returns:
        Tuple of:
        - Model predictions for each frame (num_frames,)
        - Expected progress for each frame (num_frames,)
    """
    total_frames = len(frames)

    logger.info("Encoding video frames with DINO...")
    video_embeddings = model.encode_images(frames)

    logger.info("Encoding task description with MiniLM...")
    text_embedding = model.encode_text(task_description)

    logger.info("Creating video slices (original ReWiND approach)...")
    # Convert to tensors
    video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
    text_embedding = torch.tensor(text_embedding, dtype=torch.float32)

    # Create video slices: for each frame i, create a sequence of frames [0:i+1].
    # This matches the original ReWiND inference approach.
    video_slices = []
    for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
        # Slice from start to current frame (inclusive)
        video_slice = video_embeddings[:i + 1]

        # Pad or subsample to max_length
        if model.config.subsample_video:
            video_slice = model.padding_video(video_slice, model.config.max_length)

        video_slices.append(video_slice)

    video_slices = torch.stack(video_slices)  # (num_frames, max_length, 768)

    # Create last_index_mask to extract the relevant prediction for each slice.
    # For slice i, the last valid frame is at position min(i, max_length - 1).
    max_length = model.config.max_length
    last_index_mask = torch.zeros((len(video_slices), max_length), dtype=torch.bool)

    for i in range(len(video_slices)):
        last_frame_idx = min(i, max_length - 1)
        last_index_mask[i, last_frame_idx] = True
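    # e.g., with max_length = 3, slices 0..4 select mask positions [0, 1, 2, 2, 2]:
    # once a slice exceeds max_length, its current frame sits at the last position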

    logger.info("Running ReWiND inference on all slices...")
    # Process in batches
    all_progress = []
    for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
        batch_video = video_slices[i:i + batch_size].to(model.device)
        batch_mask = last_index_mask[i:i + batch_size].to(model.device)
        batch_size_actual = batch_video.shape[0]

        # Replicate text embedding for batch
        batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)

        # Get predictions for all frames in batch
        progress_preds = model.rewind_transformer(batch_video, batch_text)  # (batch, max_length, 1)
        progress_preds = progress_preds.squeeze(-1)  # (batch, max_length)

        # Extract predictions using the last_index_mask.
        # This gets the prediction for the last valid frame in each slice.
        batch_progress = progress_preds[batch_mask].cpu().numpy()
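        # (boolean indexing with exactly one True per row yields one scalar per slice)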
        all_progress.extend(batch_progress)

    predictions = np.array(all_progress)

    # Compute expected progress based on the original ReWiND normalization.
    # When starting from frame 0, remaining_length = total_episode_length,
    # so the expected progress for frame i is (i + 1) / total_frames.
    expected_progress = np.arange(1, total_frames + 1, dtype=np.float32) / total_frames

    logger.info(f"Inference complete. Predicted progress range: [{predictions.min():.3f}, {predictions.max():.3f}]")
    logger.info(f"Expected progress range: [{expected_progress.min():.3f}, {expected_progress.max():.3f}]")

    return predictions, expected_progress


def visualize_predictions(
    frames: np.ndarray,
    predictions: np.ndarray,
    expected_progress: np.ndarray,
    task_description: str,
    output_path: Path,
    show_frames: bool = False,
    num_sample_frames: int = 8,
    figsize: tuple = (12, 6),
):
    """
    Create visualization of ReWiND predictions with expected progress comparison.

    Args:
        frames: Video frames (num_frames, H, W, C)
        predictions: Model progress predictions (num_frames,)
        expected_progress: Expected progress based on frame position (num_frames,)
        task_description: Task description
        output_path: Path to save the figure
        show_frames: Whether to include sample frames
        num_sample_frames: Number of frames to show
        figsize: Figure size (width, height)
    """
    if show_frames:
        # Create figure with progress plot and sample frames
        fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
        gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)

        # Progress plot
        ax_progress = fig.add_subplot(gs[0])
    else:
        # Just progress plot
        fig, ax_progress = plt.subplots(1, 1, figsize=figsize)

    # Plot progress over time
    frame_indices = np.arange(len(predictions))

    # Plot expected progress (ground truth)
    ax_progress.plot(frame_indices, expected_progress, linewidth=2, color='#A8DADC',
                     linestyle='--', label='Expected Progress (Linear)', alpha=0.7)

    # Plot model predictions
    ax_progress.plot(frame_indices, predictions, linewidth=2.5, color='#2E86AB',
                     label='Model Predictions')
    ax_progress.fill_between(frame_indices, 0, predictions, alpha=0.2, color='#2E86AB')

    # Add reference line at 1.0
    ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)

    # Styling
    ax_progress.set_xlabel('Frame Index', fontsize=12)
    ax_progress.set_ylabel('Task Progress', fontsize=12)
    ax_progress.set_title(f'ReWiND Task Progress Prediction\nTask: "{task_description}"',
                          fontsize=14, fontweight='bold')
    ax_progress.grid(True, alpha=0.3)
    ax_progress.set_ylim(-0.05, 1.1)
    ax_progress.legend(loc='upper left')

    # Compute alignment metrics
    mae = np.mean(np.abs(predictions - expected_progress))
    rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2))

    # Add statistics box
    stats_text = (
        f'Frames: {len(predictions)}\n'
        f'Model Final: {predictions[-1]:.3f}\n'
        f'Model Max: {predictions.max():.3f}\n'
        f'Model Mean: {predictions.mean():.3f}\n'
        f'MAE: {mae:.3f}\n'
        f'RMSE: {rmse:.3f}'
    )
    ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
                     fontsize=10, verticalalignment='bottom', horizontalalignment='right',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Show sample frames if requested
    if show_frames:
        # Select evenly spaced frames
        frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
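        # e.g., 100 frames with num_sample_frames = 8 -> indices [0, 14, 28, 42, 56, 70, 84, 99]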

        # Create subplot for frames
        ax_frames = fig.add_subplot(gs[1])
        ax_frames.axis('off')

        # Create grid for frames
        frame_height = frames[0].shape[0]
        frame_width = frames[0].shape[1]

        combined_width = frame_width * num_sample_frames
        combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)

        for i, frame_idx in enumerate(frame_indices_to_show):
            frame = frames[frame_idx]
            if frame.shape[-1] == 1:
                frame = np.repeat(frame, 3, axis=-1)

            # Add frame to combined image
            x_start = i * frame_width
            x_end = (i + 1) * frame_width
            combined_image[:, x_start:x_end] = frame

            # Add frame number and progress value
            progress_val = predictions[frame_idx]
            label = f'Frame {frame_idx}\nProgress: {progress_val:.3f}'

            # Draw label on image
            ax_frames.text(x_start + frame_width / 2, -10, label,
                           ha='center', va='top', fontsize=8,
                           bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

        ax_frames.imshow(combined_image)
        ax_frames.set_title('Sample Frames', fontsize=12, pad=20)

    # Save figure
    plt.tight_layout()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    logger.info(f"Saved visualization to {output_path}")

    plt.close()


def main():
    args = parse_args()

    # Setup device
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    logger.info(f"Using device: {device}")

    # Load model
    logger.info(f"Loading ReWiND model from {args.model_id}...")
    model = ReWiNDRewardModel.from_pretrained(args.model_id)
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully")

    # Load dataset
    logger.info(f"Loading dataset {args.dataset_repo}...")
    dataset = LeRobotDataset(args.dataset_repo)
    logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")

    # Validate episode index
    if args.episode_index >= len(dataset.meta.episodes):
        raise ValueError(
            f"Episode index {args.episode_index} out of range. "
            f"Dataset has {len(dataset.meta.episodes)} episodes."
        )

    # Determine which image key to use
    image_key = args.image_key if args.image_key is not None else model.config.image_key
    logger.info(f"Using image key: {image_key}")

    # Load episode data (this also extracts the task description from the episode)
    frames, start_idx, end_idx, dataset_task = load_episode_data(dataset, args.episode_index, image_key)

    # Use task description from dataset if available, otherwise use command-line argument
    task_description = dataset_task if dataset_task is not None else args.task_description
    logger.info(f"Using task description: '{task_description}'")

    # Run inference
    predictions, expected_progress = run_inference(model, frames, task_description)

    # Create visualization
    output_dir = Path(args.output_dir)
    output_path = output_dir / f"rewind_prediction_ep{args.episode_index}.png"

    visualize_predictions(
        frames,
        predictions,
        expected_progress,
        task_description,
        output_path,
        show_frames=args.show_frames,
        num_sample_frames=args.num_sample_frames,
        figsize=tuple(args.figsize),
    )

    # Save predictions and expected progress as numpy arrays
    predictions_path = output_dir / f"predictions_ep{args.episode_index}.npy"
    expected_path = output_dir / f"expected_progress_ep{args.episode_index}.npy"
    np.save(predictions_path, predictions)
    np.save(expected_path, expected_progress)
    logger.info(f"Saved predictions array to {predictions_path}")
    logger.info(f"Saved expected progress to {expected_path}")

    # Compute alignment metrics
    mae = np.mean(np.abs(predictions - expected_progress))
    rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2))
    correlation = np.corrcoef(predictions, expected_progress)[0, 1]

    # Print summary
    logger.info("\n" + "=" * 60)
    logger.info("INFERENCE SUMMARY")
    logger.info("=" * 60)
    logger.info(f"Model: {args.model_id}")
    logger.info(f"Dataset: {args.dataset_repo}")
    logger.info(f"Episode: {args.episode_index}")
    logger.info(f"Task: {task_description}")
    logger.info(f"Frames: {len(frames)}")
    logger.info("\nModel Predictions:")
    logger.info(f"  Final: {predictions[-1]:.3f}")
    logger.info(f"  Max: {predictions.max():.3f}")
    logger.info(f"  Mean: {predictions.mean():.3f}")
    logger.info(f"  Std: {predictions.std():.3f}")
    logger.info("\nExpected Progress (Linear):")
    logger.info(f"  Final: {expected_progress[-1]:.3f}")
    logger.info(f"  Mean: {expected_progress.mean():.3f}")
    logger.info("\nAlignment Metrics:")
    logger.info(f"  MAE: {mae:.3f}")
    logger.info(f"  RMSE: {rmse:.3f}")
    logger.info(f"  Correlation: {correlation:.3f}")
    logger.info("\nOutput:")
    logger.info(f"  Visualization: {output_path}")
    logger.info("=" * 60)

    # Diagnostic warnings
    if predictions.std() < 0.05:
        logger.warning("\n⚠ WARNING: Mode collapse detected (std < 0.05)")
        logger.warning("  Model predictions show very low variance.")
        logger.warning("  This indicates the model was likely trained with incorrect")
        logger.warning("  progress normalization (absolute indices instead of remaining length).")
    elif mae > 0.3:
        logger.warning("\n⚠ WARNING: High prediction error (MAE > 0.3)")
        logger.warning("  Model predictions deviate significantly from expected linear progress.")
        logger.warning("  Consider retraining with correct progress normalization.")
    elif correlation < 0.5:
        logger.warning("\n⚠ WARNING: Low correlation with expected progress (< 0.5)")
        logger.warning("  Model predictions don't align well with linear task progression.")
    else:
        logger.info("\n✓ Model predictions show healthy progression!")


if __name__ == "__main__":
    main()

scripts/visualize_sarm_predictions.py
@@ -0,0 +1,537 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Inference script for SARM (Stage-Aware Reward Model).

This script loads a trained SARM model and runs inference on a dataset episode,
generating visualizations of the predicted task stages and progress over time.

Example usage:
    python scripts/visualize_sarm_predictions.py \
        --model-id username/sarm-model \
        --dataset-repo lerobot/aloha_sim_insertion_human \
        --episode-index 0 \
        --output-dir outputs/sarm_viz \
        --task-description "insert the peg into the socket"
"""

import argparse
import logging
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import torch
from tqdm import tqdm

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Run SARM inference and visualize predictions")

    # Model arguments
    parser.add_argument(
        "--model-id",
        type=str,
        required=True,
        help="HuggingFace model ID or local path to trained SARM model",
    )

    # Dataset arguments
    parser.add_argument(
        "--dataset-repo",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)",
    )
    parser.add_argument(
        "--episode-index",
        type=int,
        default=0,
        help="Index of the episode to visualize (default: 0)",
    )
    parser.add_argument(
        "--task-description",
        type=str,
        default="perform the task",
        help="Task description for the reward model (default: 'perform the task')",
    )

    # Output arguments
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/sarm_inference",
        help="Directory to save visualization outputs (default: outputs/sarm_inference)",
    )
    parser.add_argument(
        "--image-key",
        type=str,
        default=None,
        help="Key for images in dataset (e.g., observation.images.image). If not specified, uses model config's image_key",
    )
    parser.add_argument(
        "--state-key",
        type=str,
        default=None,
        help="Key for joint states in dataset. If None, auto-detects from dataset",
    )

    # Visualization options
    parser.add_argument(
        "--show-frames",
        action="store_true",
        help="Include sample frames in the visualization",
    )
    parser.add_argument(
        "--num-sample-frames",
        type=int,
        default=8,
        help="Number of sample frames to show (default: 8)",
    )
    parser.add_argument(
        "--figsize",
        type=int,
        nargs=2,
        default=[14, 8],
        help="Figure size as width height (default: 14 8)",
    )

    # Device
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run inference on (cuda/cpu, default: auto-detect)",
    )

    return parser.parse_args()


def load_episode_data(
    dataset: LeRobotDataset,
    episode_index: int,
    image_key: str,
    state_key: str | None = None,
) -> tuple[np.ndarray, np.ndarray | None, int, int, str | None]:
    """
    Load all frames and states from a specific episode.

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode to load
        image_key: Key for accessing images in the dataset
        state_key: Key for accessing joint states (auto-detected if None)

    Returns:
        Tuple of (frames, states, start_index, end_index, task_description)
    """
    # Get episode boundaries
    episode_data = dataset.meta.episodes
    start_idx = episode_data["dataset_from_index"][episode_index]
    end_idx = episode_data["dataset_to_index"][episode_index]

    logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")

    # Auto-detect state key if not provided
    if state_key is None:
        first_item = dataset[start_idx]
        state_keys = [k for k in first_item.keys() if 'state' in k.lower() or 'qpos' in k.lower()]
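        # e.g., this matches keys such as "observation.state" or "qpos"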
        if state_keys:
            state_key = state_keys[0]
            logger.info(f"Auto-detected state key: {state_key}")

    # Get task description from the dataset if available
    task_description = None
    first_item = dataset[start_idx]
    if "task" in first_item:
        task_description = first_item["task"]
        logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")

    # Load all frames and states from the episode
    frames = []
    states = []
    for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
        item = dataset[idx]

        # Get image
        img = item[image_key]

        # Convert to numpy if needed
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()

        # Handle different image formats (C, H, W) or (H, W, C)
        if img.shape[0] in [1, 3]:  # Channel first
            img = np.transpose(img, (1, 2, 0))

        # Convert to uint8 if needed
        if img.dtype != np.uint8:
            if img.max() <= 1.0:
                img = (img * 255).astype(np.uint8)
            else:
                img = img.astype(np.uint8)

        frames.append(img)

        # Get state if available
        if state_key and state_key in item:
            state = item[state_key]
            if isinstance(state, torch.Tensor):
                state = state.cpu().numpy()
            states.append(state)

    frames = np.array(frames)
    states = np.array(states) if states else None
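    # states: (num_frames, state_dim), or None when no state key was found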
    logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
    if states is not None:
        logger.info(f"Loaded states with shape {states.shape}")

    return frames, states, start_idx, end_idx, task_description


@torch.no_grad()
def run_inference(
    model: SARMRewardModel,
    frames: np.ndarray,
    states: Optional[np.ndarray],
    task_description: str,
    batch_size: int = 32,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run SARM inference on video frames and joint states.

    For each frame t, creates a temporal sequence of 9 frames using SARM's pattern:
        [t-240, t-210, t-180, t-150, t-120, t-90, t-60, t-30, t]
    This matches the training pattern, where frames are loaded with 30-frame gaps
    relative to the current frame.

    Args:
        model: SARM model
        frames: Video frames (num_frames, H, W, C)
        states: Joint states (num_frames, state_dim)
        task_description: Task description text
        batch_size: Batch size for processing slices

    Returns:
        Tuple of (progress_predictions, stage_predictions)
        - progress_predictions: (num_frames,)
        - stage_predictions: (num_frames, num_stages)
    """
    logger.info("Encoding video frames with CLIP...")
    video_embeddings = model.encode_images(frames)

    logger.info("Encoding task description with MiniLM...")
    text_embedding = model.encode_text(task_description)

    logger.info("Creating video slices (SARM approach)...")
    # Convert to tensors
    video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
    text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
    if states is not None:
        state_embeddings = torch.tensor(states, dtype=torch.float32)
    else:
        state_embeddings = None

    # Create video slices: for each frame i, create a sequence using SARM's pattern.
    # For SARM: 9 frames relative to current, with 30-frame gaps.
    # Pattern: [current-240, current-210, ..., current-30, current]
    num_frames_model = model.config.num_frames
    frame_gap = model.config.frame_gap

    video_slices = []
    state_slices = []

    for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
        # For SARM, create sequence relative to current frame (matching training pattern)
        # Pattern: [current-240, current-210, ..., current-30, current]
        # This matches observation_delta_indices: range(-240, 1, 30)

        # Compute frame indices for this slice (relative to current frame i)
        frame_indices = []
        for j in range(num_frames_model):
            # Start from -(num_frames_model - 1) * frame_gap and go to 0
            offset = -(num_frames_model - 1 - j) * frame_gap
            idx = i + offset

            # Clamp to valid range [0, current_frame]
            if idx < 0:
                idx = 0  # Pad with first available frame

            frame_indices.append(idx)
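        # e.g., with num_frames_model = 9 and frame_gap = 30, frame i = 50 gives
        # offsets [-240, -210, ..., -30, 0] and clamped indices [0, 0, 0, 0, 0, 0, 0, 20, 50]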

        # Extract slice
        video_slice = video_embeddings[frame_indices]
        video_slices.append(video_slice)

        if state_embeddings is not None:
            state_slice = state_embeddings[frame_indices]
            state_slices.append(state_slice)

    video_slices = torch.stack(video_slices)  # (num_frames, num_frames_model, 512)
    if state_embeddings is not None:
        state_slices = torch.stack(state_slices)  # (num_frames, num_frames_model, state_dim)
    else:
        state_slices = None

    logger.info("Running SARM inference on all slices...")
    # Process in batches
    all_progress = []
    all_stages = []

    for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
        batch_video = video_slices[i:i + batch_size].to(model.device)
        batch_states = state_slices[i:i + batch_size].to(model.device) if state_slices is not None else None
        batch_size_actual = batch_video.shape[0]

        # Replicate text embedding for batch
        batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)

        # Get predictions
        stage_logits, stage_probs, progress_preds = model.sarm_transformer(
            batch_video, batch_text, batch_states
        )

        # Extract last frame predictions (the "current" frame):
        # for SARM, we take the last frame in each sequence.
        batch_progress = progress_preds[:, -1, 0].cpu().numpy()
        batch_stages = stage_probs[:, -1, :].cpu().numpy()

        all_progress.extend(batch_progress)
        all_stages.extend(batch_stages)

    return np.array(all_progress), np.array(all_stages)


def visualize_predictions(
    frames: np.ndarray,
    progress_predictions: np.ndarray,
    stage_predictions: np.ndarray,
    task_description: str,
    output_path: Path,
    show_frames: bool = False,
    num_sample_frames: int = 8,
    figsize: tuple = (14, 8),
):
    """
    Create visualization of SARM predictions.

    Args:
        frames: Video frames (num_frames, H, W, C)
        progress_predictions: Progress predictions (num_frames,)
        stage_predictions: Stage probabilities (num_frames, num_stages)
        task_description: Task description
        output_path: Path to save the figure
        show_frames: Whether to include sample frames
        num_sample_frames: Number of frames to show
        figsize: Figure size (width, height)
    """
    num_stages = stage_predictions.shape[1]
    stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))

    if show_frames:
        # Create figure with progress plot, stage plot, and sample frames
        fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
        gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)

        ax_progress = fig.add_subplot(gs[0])
        ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
        ax_frames = fig.add_subplot(gs[2])
    else:
        # Just progress and stage plots
        fig = plt.figure(figsize=figsize)
        gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)

        ax_progress = fig.add_subplot(gs[0])
        ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)

    frame_indices = np.arange(len(progress_predictions))

    # Plot 1: Progress over time
    ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
    ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')
    ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
    ax_progress.set_ylabel('Task Progress', fontsize=12)
    ax_progress.set_title(f'SARM Task Progress & Stage Prediction\nTask: "{task_description}"',
                          fontsize=14, fontweight='bold')
    ax_progress.grid(True, alpha=0.3)
    ax_progress.set_ylim(-0.05, 1.1)
    ax_progress.legend(loc='upper left')

    # Add statistics box
    stats_text = (
        f'Frames: {len(progress_predictions)}\n'
        f'Final Progress: {progress_predictions[-1]:.3f}\n'
        f'Max Progress: {progress_predictions.max():.3f}\n'
        f'Mean Progress: {progress_predictions.mean():.3f}'
    )
    ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
                     fontsize=10, verticalalignment='bottom', horizontalalignment='right',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Plot 2: Stage predictions (stacked area plot)
    ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
                        colors=stage_colors, alpha=0.8, labels=[f'Stage {i + 1}' for i in range(num_stages)])
    ax_stages.set_xlabel('Frame Index', fontsize=12)
    ax_stages.set_ylabel('Stage Probability', fontsize=12)
    ax_stages.set_ylim(0, 1)
    ax_stages.grid(True, alpha=0.3)
    ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)

    # Plot 3: Sample frames (if requested)
    if show_frames:
        frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)

        ax_frames.axis('off')

        # Create grid for frames
        frame_height = frames[0].shape[0]
        frame_width = frames[0].shape[1]

        combined_width = frame_width * num_sample_frames
        combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)

        for i, frame_idx in enumerate(frame_indices_to_show):
            frame = frames[frame_idx]
            if frame.shape[-1] == 1:
                frame = np.repeat(frame, 3, axis=-1)

            # Add frame to combined image
            x_start = i * frame_width
            x_end = (i + 1) * frame_width
            combined_image[:, x_start:x_end] = frame

            # Add frame number, progress, and stage
            progress_val = progress_predictions[frame_idx]
            stage_idx = np.argmax(stage_predictions[frame_idx])
            label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\nStage: {stage_idx + 1}'

            # Draw label on image
            ax_frames.text(x_start + frame_width / 2, -10, label,
                           ha='center', va='top', fontsize=7,
                           bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

        ax_frames.imshow(combined_image)
        ax_frames.set_title('Sample Frames', fontsize=12, pad=20)

    # Save figure
    plt.tight_layout()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    logger.info(f"Saved visualization to {output_path}")

    plt.close()


def main():
    args = parse_args()

    # Setup device
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    logger.info(f"Using device: {device}")

    # Load model
    logger.info(f"Loading SARM model from {args.model_id}...")
    model = SARMRewardModel.from_pretrained(args.model_id)
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully")

    # Load dataset
    logger.info(f"Loading dataset {args.dataset_repo}...")
    dataset = LeRobotDataset(args.dataset_repo)
    logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")

    # Validate episode index
    if args.episode_index >= len(dataset.meta.episodes):
        raise ValueError(
            f"Episode index {args.episode_index} out of range. "
            f"Dataset has {len(dataset.meta.episodes)} episodes."
        )

    # Determine which image key to use
    image_key = args.image_key if args.image_key is not None else model.config.image_key
    logger.info(f"Using image key: {image_key}")

    # Load episode data
    frames, states, start_idx, end_idx, dataset_task = load_episode_data(
        dataset, args.episode_index, image_key, args.state_key
    )

    # Use task description from dataset if available, otherwise use command-line argument
    task_description = dataset_task if dataset_task is not None else args.task_description
    logger.info(f"Using task description: '{task_description}'")

    # Run inference
    progress_predictions, stage_predictions = run_inference(model, frames, states, task_description)

    # Create visualization
    output_dir = Path(args.output_dir)
    output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"

    visualize_predictions(
        frames,
        progress_predictions,
        stage_predictions,
        task_description,
        output_path,
        show_frames=args.show_frames,
        num_sample_frames=args.num_sample_frames,
        figsize=tuple(args.figsize),
    )

    # Save predictions as numpy arrays
    predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
    np.savez(predictions_path, progress=progress_predictions, stages=stage_predictions)
    logger.info(f"Saved predictions to {predictions_path}")

    # Print summary
    logger.info("\n" + "=" * 60)
    logger.info("INFERENCE SUMMARY")
    logger.info("=" * 60)
    logger.info(f"Model: {args.model_id}")
    logger.info(f"Dataset: {args.dataset_repo}")
    logger.info(f"Episode: {args.episode_index}")
    logger.info(f"Task: {task_description}")
    logger.info(f"Frames: {len(frames)}")
    logger.info(f"Final Progress: {progress_predictions[-1]:.3f}")
    logger.info(f"Max Progress: {progress_predictions.max():.3f}")
    logger.info(f"Mean Progress: {progress_predictions.mean():.3f}")
    logger.info(f"Most Common Stage: {np.argmax(np.sum(stage_predictions, axis=0)) + 1}")
    logger.info(f"Visualization: {output_path}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()