mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
529 lines
19 KiB
Python
529 lines
19 KiB
Python
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""
Inference script for ReWiND Reward Model.

This script loads a trained ReWiND model and runs inference on a dataset episode,
generating visualizations of the predicted task progression over time.

Example usage:
    python scripts/visualize_rewind_predictions.py \
        --model-id username/rewind-model \
        --dataset-repo lerobot/aloha_sim_insertion_human \
        --episode-index 0 \
        --output-dir outputs/rewind_viz \
        --task-description "insert the peg into the socket"
"""
|
|
|
|
import argparse
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.gridspec as gridspec
|
|
import numpy as np
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel
|
|
|
|
|
|
# Configure root logging once at import time so every message from this
# script (INFO and above) is emitted, then create the module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """
    Parse the command-line options of the ReWiND inference script.

    Args:
        argv: Argument strings to parse. When None (the default) argparse reads
            ``sys.argv[1:]``, which is the behaviour of the original
            zero-argument call.

    Returns:
        Namespace with the attributes ``model_id``, ``dataset_repo``,
        ``episode_index``, ``task_description``, ``output_dir``, ``image_key``,
        ``show_frames``, ``num_sample_frames``, ``figsize`` and ``device``.
    """
    parser = argparse.ArgumentParser(description="Run ReWiND inference and visualize predictions")

    # Model arguments
    parser.add_argument(
        "--model-id",
        type=str,
        required=True,
        help="HuggingFace model ID or local path to trained ReWiND model",
    )

    # Dataset arguments
    parser.add_argument(
        "--dataset-repo",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)",
    )
    parser.add_argument(
        "--episode-index",
        type=int,
        default=0,
        help="Index of the episode to visualize (default: 0)",
    )
    parser.add_argument(
        "--task-description",
        type=str,
        default="perform the task",
        help="Task description for the reward model (default: 'perform the task')",
    )

    # Output arguments
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/rewind_inference",
        help="Directory to save visualization outputs (default: outputs/rewind_inference)",
    )
    parser.add_argument(
        "--image-key",
        type=str,
        default=None,
        help="Key for images in dataset (e.g., observation.images.image for jaco_play). If not specified, uses model config's image_key",
    )

    # Visualization options
    parser.add_argument(
        "--show-frames",
        action="store_true",
        help="Include sample frames in the visualization",
    )
    parser.add_argument(
        "--num-sample-frames",
        type=int,
        default=8,
        help="Number of sample frames to show (default: 8)",
    )
    parser.add_argument(
        "--figsize",
        # float is a superset of the previous int parsing: whole numbers are
        # still accepted, and fractional figure sizes become possible.
        type=float,
        nargs=2,
        default=[12, 6],
        help="Figure size as width height (default: 12 6)",
    )

    # Device
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run inference on (cuda/cpu, default: auto-detect)",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
def _to_uint8_hwc(img) -> np.ndarray:
    """Convert one image (tensor or array-like, CHW or HWC) to a uint8 HWC array."""
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    else:
        img = np.asarray(img)

    # Treat a leading axis of size 1 or 3 as the channel axis: (C, H, W) -> (H, W, C).
    if img.ndim == 3 and img.shape[0] in (1, 3):
        img = np.transpose(img, (1, 2, 0))

    if img.dtype != np.uint8:
        # Values with a maximum <= 1.0 are taken to be normalised to [0, 1];
        # anything else is assumed to already be on a 0-255 scale.
        if img.max() <= 1.0:
            img = img * 255
        # Clip before casting so out-of-range values saturate instead of
        # wrapping around in the uint8 conversion.
        img = np.clip(img, 0, 255).astype(np.uint8)

    return img


def load_episode_data(
    dataset: LeRobotDataset,
    episode_index: int,
    image_key: str,
) -> tuple[np.ndarray, int, int, Optional[str]]:
    """
    Load all frames from a specific episode.

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode to load
        image_key: Key for accessing images in the dataset

    Returns:
        Tuple of (frames, start_index, end_index, task_description), where
        frames is a uint8 array stacked along a new leading axis and
        task_description is None when the first item of the episode has no
        "task" entry.

    Raises:
        ValueError: If the episode boundaries describe an empty episode.
    """
    # Episode boundaries are stored as a half-open [from, to) range of dataset indices.
    episode_data = dataset.meta.episodes
    start_idx = int(episode_data["dataset_from_index"][episode_index])
    end_idx = int(episode_data["dataset_to_index"][episode_index])

    logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")

    if end_idx <= start_idx:
        raise ValueError(
            f"Episode {episode_index} contains no frames (indices {start_idx} to {end_idx})."
        )

    task_description = None
    frames = []
    for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
        item = dataset[idx]

        # The task is read from the first item while it is loaded anyway,
        # instead of fetching (and decoding) that item a second time.
        if idx == start_idx and "task" in item:
            task_description = item["task"]
            logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")

        frames.append(_to_uint8_hwc(item[image_key]))

    frames = np.stack(frames)
    logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")

    return frames, start_idx, end_idx, task_description
|
|
|
|
|
|
@torch.no_grad()
def run_inference(
    model: ReWiNDRewardModel,
    frames: np.ndarray,
    task_description: str,
    batch_size: int = 32,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run ReWiND inference on video frames using the original ReWiND approach.

    This function creates video slices for all frames at once (similar to the original
    metaworld_label_reward.py), where each slice contains frames from start up to that point.

    Progress Normalization (from original ReWiND dataset.py):
    - Training: progress = [1, 2, ..., N] / remaining_length
      where remaining_length = episode_end - sequence_start
    - Inference: Starting from frame 0, remaining_length = total_episode_length
      So expected progress for frame i = (i + 1) / total_episode_length

    This function computes both:
    1. Model predictions (what the model actually predicts)
    2. Expected progress (ground truth based on frame position)

    Args:
        model: ReWiND model
        frames: Video frames (num_frames, H, W, C)
        task_description: Task description text
        batch_size: Batch size for processing slices

    Returns:
        Tuple of:
        - Model predictions for each frame (num_frames,)
        - Expected progress for each frame (num_frames,)

    Raises:
        ValueError: If ``frames`` is empty or ``batch_size`` is not positive.
    """
    total_frames = len(frames)
    if total_frames == 0:
        raise ValueError("Cannot run inference on an empty frame sequence.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    logger.info("Encoding video frames with DINO...")
    video_embeddings = model.encode_images(frames)

    logger.info("Encoding task description with MiniLM...")
    text_embedding = model.encode_text(task_description)

    logger.info("Creating video slices (original ReWiND approach)...")
    # as_tensor accepts NumPy arrays and tensors alike, so an encoder that
    # already returns a tensor does not trigger torch.tensor()'s copy warning.
    video_embeddings = torch.as_tensor(video_embeddings, dtype=torch.float32)
    text_embedding = torch.as_tensor(text_embedding, dtype=torch.float32)

    max_length = model.config.max_length
    subsample_video = model.config.subsample_video
    num_slices = len(video_embeddings)

    # Create video slices: for each frame i, create a sequence of frames [0:i+1]
    # This matches the original ReWiND inference approach
    video_slices = []
    for i in tqdm(range(num_slices), desc="Creating slices"):
        # Slice from start to current frame (inclusive)
        video_slice = video_embeddings[: i + 1]

        # Pad or subsample to max_length
        # NOTE(review): when subsample_video is False the slices keep different
        # lengths and the torch.stack below cannot succeed for more than one
        # frame - confirm whether padding should be unconditional.
        if subsample_video:
            video_slice = model.padding_video(video_slice, max_length)

        video_slices.append(video_slice)

    video_slices = torch.stack(video_slices)  # (num_frames, max_length, 768)

    # Create last_index_mask to extract the relevant prediction for each slice
    # For slice i, the last valid frame is at position min(i, max_length-1)
    slice_rows = torch.arange(num_slices)
    last_frame_positions = slice_rows.clamp(max=max_length - 1)
    last_index_mask = torch.zeros((num_slices, max_length), dtype=torch.bool)
    last_index_mask[slice_rows, last_frame_positions] = True

    logger.info("Running ReWiND inference on all slices...")
    device = model.device
    # The text embedding is identical for every batch, so move it to the device once.
    text_embedding = text_embedding.to(device)

    # Process in batches
    batch_predictions = []
    for start in tqdm(range(0, num_slices, batch_size), desc="Inference"):
        batch_video = video_slices[start : start + batch_size].to(device)
        batch_mask = last_index_mask[start : start + batch_size].to(device)

        # Replicate text embedding for batch (the last batch may be smaller)
        batch_text = text_embedding.unsqueeze(0).repeat(batch_video.shape[0], 1)

        # Get predictions for all frames in batch
        progress_preds = model.rewind_transformer(batch_video, batch_text)  # (batch, max_length, 1)
        progress_preds = progress_preds.squeeze(-1)  # (batch, max_length)

        # Extract predictions using the last_index_mask
        # This gets the prediction for the last valid frame in each slice
        batch_predictions.append(progress_preds[batch_mask].cpu().numpy())

    predictions = np.concatenate(batch_predictions)

    # Compute expected progress based on original ReWiND normalization
    # When starting from frame 0, remaining_length = total_episode_length
    # Expected progress for frame i = (i + 1) / total_frames
    expected_progress = np.arange(1, total_frames + 1, dtype=np.float32) / total_frames

    logger.info(f"Inference complete. Predicted progress range: [{predictions.min():.3f}, {predictions.max():.3f}]")
    logger.info(f"Expected progress range: [{expected_progress.min():.3f}, {expected_progress.max():.3f}]")

    return predictions, expected_progress
|
|
|
|
|
|
def _draw_sample_frames(ax_frames, frames: np.ndarray, predictions: np.ndarray, num_sample_frames: int) -> None:
    """Draw a horizontal strip of evenly spaced frames, each labelled with its prediction."""
    ax_frames.axis('off')

    # Keep at least one frame (a zero-width strip cannot be rendered) and never
    # more than exist (which would only repeat frames).
    num_shown = max(1, min(num_sample_frames, len(frames)))

    # Select evenly spaced frames
    frame_indices_to_show = np.linspace(0, len(frames) - 1, num_shown, dtype=int)

    # Create grid for frames
    frame_height, frame_width = frames[0].shape[:2]
    combined_image = np.zeros((frame_height, frame_width * num_shown, 3), dtype=np.uint8)

    for slot, frame_idx in enumerate(frame_indices_to_show):
        frame = frames[frame_idx]
        if frame.shape[-1] == 1:
            # Grayscale: replicate the single channel to RGB
            frame = np.repeat(frame, 3, axis=-1)
        elif frame.shape[-1] > 3:
            # Drop extra channels (e.g. alpha) so the frame fits the RGB strip
            frame = frame[..., :3]

        # Add frame to combined image
        x_start = slot * frame_width
        combined_image[:, x_start:x_start + frame_width] = frame

        # Add frame number and progress value
        label = f'Frame {frame_idx}\nProgress: {predictions[frame_idx]:.3f}'

        # Draw label on image
        ax_frames.text(x_start + frame_width / 2, -10, label,
                       ha='center', va='top', fontsize=8,
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

    ax_frames.imshow(combined_image)
    ax_frames.set_title('Sample Frames', fontsize=12, pad=20)


def visualize_predictions(
    frames: np.ndarray,
    predictions: np.ndarray,
    expected_progress: np.ndarray,
    task_description: str,
    output_path: Path,
    show_frames: bool = False,
    num_sample_frames: int = 8,
    figsize: tuple = (12, 6),
):
    """
    Create visualization of ReWiND predictions with expected progress comparison.

    Args:
        frames: Video frames (num_frames, H, W, C)
        predictions: Model progress predictions (num_frames,)
        expected_progress: Expected progress based on frame position (num_frames,)
        task_description: Task description
        output_path: Path to save the figure; parent directories are created
        show_frames: Whether to include sample frames
        num_sample_frames: Number of frames to show; clamped to the range
            [1, num_frames]
        figsize: Figure size (width, height)

    Raises:
        ValueError: If ``predictions`` is empty.
    """
    if len(predictions) == 0:
        raise ValueError("Cannot visualize an empty prediction array.")

    gs = None
    if show_frames:
        # Create figure with progress plot and sample frames
        fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
        gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)

        # Progress plot
        ax_progress = fig.add_subplot(gs[0])
    else:
        # Just progress plot
        fig, ax_progress = plt.subplots(1, 1, figsize=figsize)

    # The figure is closed in `finally` so it is released even if plotting or
    # saving fails part-way through.
    try:
        # Plot progress over time
        frame_indices = np.arange(len(predictions))

        # Plot expected progress (ground truth)
        ax_progress.plot(frame_indices, expected_progress, linewidth=2, color='#A8DADC',
                         linestyle='--', label='Expected Progress (Linear)', alpha=0.7)

        # Plot model predictions
        ax_progress.plot(frame_indices, predictions, linewidth=2.5, color='#2E86AB',
                         label='Model Predictions')
        ax_progress.fill_between(frame_indices, 0, predictions, alpha=0.2, color='#2E86AB')

        # Add reference line at 1.0
        ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)

        # Styling
        ax_progress.set_xlabel('Frame Index', fontsize=12)
        ax_progress.set_ylabel('Task Progress', fontsize=12)
        ax_progress.set_title(f'ReWiND Task Progress Prediction\nTask: "{task_description}"',
                              fontsize=14, fontweight='bold')
        ax_progress.grid(True, alpha=0.3)
        ax_progress.set_ylim(-0.05, 1.1)
        ax_progress.legend(loc='upper left')

        # Compute alignment metrics
        errors = predictions - expected_progress
        mae = np.mean(np.abs(errors))
        rmse = np.sqrt(np.mean(errors ** 2))

        # Add statistics box
        stats_text = (
            f'Frames: {len(predictions)}\n'
            f'Model Final: {predictions[-1]:.3f}\n'
            f'Model Max: {predictions.max():.3f}\n'
            f'Model Mean: {predictions.mean():.3f}\n'
            f'MAE: {mae:.3f}\n'
            f'RMSE: {rmse:.3f}'
        )
        ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
                         fontsize=10, verticalalignment='bottom', horizontalalignment='right',
                         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        # Show sample frames if requested
        if show_frames:
            _draw_sample_frames(fig.add_subplot(gs[1]), frames, predictions, num_sample_frames)

        # Save figure
        fig.tight_layout()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.info(f"Saved visualization to {output_path}")
    finally:
        plt.close(fig)
|
|
|
|
|
|
def main():
    """
    Run the full pipeline: parse arguments, load the model and dataset,
    run inference on one episode, save the plot and raw arrays, and log a
    summary with simple diagnostic warnings.

    Raises:
        ValueError: If the requested episode index is outside the dataset.
    """
    args = parse_args()

    # Setup device
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    logger.info(f"Using device: {device}")

    # Load model
    logger.info(f"Loading ReWiND model from {args.model_id}...")
    model = ReWiNDRewardModel.from_pretrained(args.model_id)
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully")

    # Load dataset
    logger.info(f"Loading dataset {args.dataset_repo}...")
    dataset = LeRobotDataset(args.dataset_repo)
    num_episodes = len(dataset.meta.episodes)
    logger.info(f"Dataset loaded: {num_episodes} episodes, {len(dataset)} frames")

    # Validate episode index (negative values are rejected too, so they cannot
    # silently index from the end of the episode table)
    if not 0 <= args.episode_index < num_episodes:
        raise ValueError(
            f"Episode index {args.episode_index} out of range. "
            f"Dataset has {num_episodes} episodes."
        )

    # Determine which image key to use
    image_key = args.image_key if args.image_key is not None else model.config.image_key
    logger.info(f"Using image key: {image_key}")

    # Load episode data (this also extracts the task description from the episode)
    frames, _start_idx, _end_idx, dataset_task = load_episode_data(dataset, args.episode_index, image_key)

    # Use task description from dataset if available, otherwise use command-line argument
    task_description = dataset_task if dataset_task is not None else args.task_description
    logger.info(f"Using task description: '{task_description}'")

    # Run inference
    predictions, expected_progress = run_inference(model, frames, task_description)

    # Create the output directory here so saving the arrays below does not
    # depend on the plotting step having created it.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"rewind_prediction_ep{args.episode_index}.png"

    # Create visualization
    visualize_predictions(
        frames,
        predictions,
        expected_progress,
        task_description,
        output_path,
        show_frames=args.show_frames,
        num_sample_frames=args.num_sample_frames,
        figsize=tuple(args.figsize),
    )

    # Save predictions and expected progress as numpy arrays
    predictions_path = output_dir / f"predictions_ep{args.episode_index}.npy"
    expected_path = output_dir / f"expected_progress_ep{args.episode_index}.npy"
    np.save(predictions_path, predictions)
    np.save(expected_path, expected_progress)
    logger.info(f"Saved predictions array to {predictions_path}")
    logger.info(f"Saved expected progress to {expected_path}")

    # Compute alignment metrics
    errors = predictions - expected_progress
    mae = np.mean(np.abs(errors))
    rmse = np.sqrt(np.mean(errors ** 2))
    correlation = np.corrcoef(predictions, expected_progress)[0, 1]

    # Print summary
    logger.info("\n" + "="*60)
    logger.info("INFERENCE SUMMARY")
    logger.info("="*60)
    logger.info(f"Model: {args.model_id}")
    logger.info(f"Dataset: {args.dataset_repo}")
    logger.info(f"Episode: {args.episode_index}")
    logger.info(f"Task: {task_description}")
    logger.info(f"Frames: {len(frames)}")
    logger.info("\nModel Predictions:")
    logger.info(f" Final: {predictions[-1]:.3f}")
    logger.info(f" Max: {predictions.max():.3f}")
    logger.info(f" Mean: {predictions.mean():.3f}")
    logger.info(f" Std: {predictions.std():.3f}")
    logger.info("\nExpected Progress (Linear):")
    logger.info(f" Final: {expected_progress[-1]:.3f}")
    logger.info(f" Mean: {expected_progress.mean():.3f}")
    logger.info("\nAlignment Metrics:")
    logger.info(f" MAE: {mae:.3f}")
    logger.info(f" RMSE: {rmse:.3f}")
    logger.info(f" Correlation: {correlation:.3f}")
    logger.info("\nOutput:")
    logger.info(f" Visualization: {output_path}")
    logger.info("="*60)

    # Diagnostic warnings (only the first matching condition is reported)
    if predictions.std() < 0.05:
        logger.warning("\n⚠ WARNING: Mode collapse detected (std < 0.05)")
        logger.warning(" Model predictions show very low variance.")
        logger.warning(" This indicates the model was likely trained with incorrect")
        logger.warning(" progress normalization (absolute indices instead of remaining length).")
    elif mae > 0.3:
        logger.warning("\n⚠ WARNING: High prediction error (MAE > 0.3)")
        logger.warning(" Model predictions deviate significantly from expected linear progress.")
        logger.warning(" Consider retraining with correct progress normalization.")
    elif correlation < 0.5:
        logger.warning("\n⚠ WARNING: Low correlation with expected progress (< 0.5)")
        logger.warning(" Model predictions don't align well with linear task progression.")
    else:
        logger.info("\n✓ Model predictions show healthy progression!")
|
|
|
|
|
|
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|