Remove rewind, use clip tokenizer

This commit is contained in:
Pepijn
2025-11-26 21:06:20 +01:00
parent 425eced2de
commit 3ed0425d2c
17 changed files with 172 additions and 2668 deletions
-528
View File
@@ -1,528 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inference script for ReWiND Reward Model.
This script loads a trained ReWiND model and runs inference on a dataset episode,
generating visualizations of the predicted task progression over time.
Example usage:
python scripts/visualize_rewind_predictions.py \
--model-id username/rewind-model \
--dataset-repo lerobot/aloha_sim_insertion_human \
--episode-index 0 \
--output-dir outputs/rewind_viz \
--task-description "insert the peg into the socket"
"""
import argparse
import logging
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Run ReWiND inference and visualize predictions")
# Model arguments
parser.add_argument(
"--model-id",
type=str,
required=True,
help="HuggingFace model ID or local path to trained ReWiND model"
)
# Dataset arguments
parser.add_argument(
"--dataset-repo",
type=str,
required=True,
help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)"
)
parser.add_argument(
"--episode-index",
type=int,
default=0,
help="Index of the episode to visualize (default: 0)"
)
parser.add_argument(
"--task-description",
type=str,
default="perform the task",
help="Task description for the reward model (default: 'perform the task')"
)
# Output arguments
parser.add_argument(
"--output-dir",
type=str,
default="outputs/rewind_inference",
help="Directory to save visualization outputs (default: outputs/rewind_inference)"
)
parser.add_argument(
"--image-key",
type=str,
default=None,
help="Key for images in dataset (e.g., observation.images.image for jaco_play). If not specified, uses model config's image_key"
)
# Visualization options
parser.add_argument(
"--show-frames",
action="store_true",
help="Include sample frames in the visualization"
)
parser.add_argument(
"--num-sample-frames",
type=int,
default=8,
help="Number of sample frames to show (default: 8)"
)
parser.add_argument(
"--figsize",
type=int,
nargs=2,
default=[12, 6],
help="Figure size as width height (default: 12 6)"
)
# Device
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to run inference on (cuda/cpu, default: auto-detect)"
)
return parser.parse_args()
def load_episode_data(
dataset: LeRobotDataset,
episode_index: int,
image_key: str
) -> tuple[np.ndarray, int, int, str]:
"""
Load all frames from a specific episode.
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode to load
image_key: Key for accessing images in the dataset
Returns:
Tuple of (frames, start_index, end_index, task_description)
"""
# Get episode boundaries
episode_data = dataset.meta.episodes
start_idx = episode_data["dataset_from_index"][episode_index]
end_idx = episode_data["dataset_to_index"][episode_index]
logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")
# Get task description from the dataset if available
task_description = None
first_item = dataset[start_idx]
if "task" in first_item:
task_description = first_item["task"]
print(f"✓ Extracted task from episode {episode_index}: '{task_description}'")
# Load all frames from the episode
frames = []
for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
item = dataset[idx]
# Get image from the item
img = item[image_key]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
# Handle different image formats (C, H, W) or (H, W, C)
if img.shape[0] in [1, 3]: # Channel first
img = np.transpose(img, (1, 2, 0))
# Convert to uint8 if needed
if img.dtype != np.uint8:
if img.max() <= 1.0:
img = (img * 255).astype(np.uint8)
else:
img = img.astype(np.uint8)
frames.append(img)
frames = np.array(frames)
logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
return frames, start_idx, end_idx, task_description
@torch.no_grad()
def run_inference(
model: ReWiNDRewardModel,
frames: np.ndarray,
task_description: str,
batch_size: int = 32
) -> tuple[np.ndarray, np.ndarray]:
"""
Run ReWiND inference on video frames using the original ReWiND approach.
This function creates video slices for all frames at once (similar to the original
metaworld_label_reward.py), where each slice contains frames from start up to that point.
Progress Normalization (from original ReWiND dataset.py):
- Training: progress = [1, 2, ..., N] / remaining_length
where remaining_length = episode_end - sequence_start
- Inference: Starting from frame 0, remaining_length = total_episode_length
So expected progress for frame i = (i + 1) / total_episode_length
This function computes both:
1. Model predictions (what the model actually predicts)
2. Expected progress (ground truth based on frame position)
Args:
model: ReWiND model
frames: Video frames (num_frames, H, W, C)
task_description: Task description text
batch_size: Batch size for processing slices
Returns:
Tuple of:
- Model predictions for each frame (num_frames,)
- Expected progress for each frame (num_frames,)
"""
total_frames = len(frames)
logger.info("Encoding video frames with DINO...")
video_embeddings = model.encode_images(frames)
logger.info("Encoding task description with MiniLM...")
text_embedding = model.encode_text(task_description)
logger.info("Creating video slices (original ReWiND approach)...")
# Convert to tensors
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
# Create video slices: for each frame i, create a sequence of frames [0:i+1]
# This matches the original ReWiND inference approach
video_slices = []
for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Slice from start to current frame (inclusive)
video_slice = video_embeddings[:i + 1]
# Pad or subsample to max_length
if model.config.subsample_video:
video_slice = model.padding_video(video_slice, model.config.max_length)
video_slices.append(video_slice)
video_slices = torch.stack(video_slices) # (num_frames, max_length, 768)
# Create last_index_mask to extract the relevant prediction for each slice
# For slice i, the last valid frame is at position min(i, max_length-1)
max_length = model.config.max_length
last_index_mask = torch.zeros((len(video_slices), max_length), dtype=torch.bool)
for i in range(len(video_slices)):
last_frame_idx = min(i, max_length - 1)
last_index_mask[i, last_frame_idx] = 1
logger.info("Running ReWiND inference on all slices...")
# Process in batches
all_progress = []
for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
batch_video = video_slices[i:i + batch_size].to(model.device)
batch_mask = last_index_mask[i:i + batch_size].to(model.device)
batch_size_actual = batch_video.shape[0]
# Replicate text embedding for batch
batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)
# Get predictions for all frames in batch
progress_preds = model.rewind_transformer(batch_video, batch_text) # (batch, max_length, 1)
progress_preds = progress_preds.squeeze(-1) # (batch, max_length)
# Extract predictions using the last_index_mask
# This gets the prediction for the last valid frame in each slice
batch_progress = progress_preds[batch_mask].cpu().numpy()
all_progress.extend(batch_progress)
predictions = np.array(all_progress)
# Compute expected progress based on original ReWiND normalization
# When starting from frame 0, remaining_length = total_episode_length
# Expected progress for frame i = (i + 1) / total_frames
expected_progress = np.arange(1, total_frames + 1, dtype=np.float32) / total_frames
logger.info(f"Inference complete. Predicted progress range: [{predictions.min():.3f}, {predictions.max():.3f}]")
logger.info(f"Expected progress range: [{expected_progress.min():.3f}, {expected_progress.max():.3f}]")
return predictions, expected_progress
def visualize_predictions(
frames: np.ndarray,
predictions: np.ndarray,
expected_progress: np.ndarray,
task_description: str,
output_path: Path,
show_frames: bool = False,
num_sample_frames: int = 8,
figsize: tuple = (12, 6)
):
"""
Create visualization of ReWiND predictions with expected progress comparison.
Args:
frames: Video frames (num_frames, H, W, C)
predictions: Model progress predictions (num_frames,)
expected_progress: Expected progress based on frame position (num_frames,)
task_description: Task description
output_path: Path to save the figure
show_frames: Whether to include sample frames
num_sample_frames: Number of frames to show
figsize: Figure size (width, height)
"""
if show_frames:
# Create figure with progress plot and sample frames
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)
# Progress plot
ax_progress = fig.add_subplot(gs[0])
else:
# Just progress plot
fig, ax_progress = plt.subplots(1, 1, figsize=figsize)
# Plot progress over time
frame_indices = np.arange(len(predictions))
# Plot expected progress (ground truth)
ax_progress.plot(frame_indices, expected_progress, linewidth=2, color='#A8DADC',
linestyle='--', label='Expected Progress (Linear)', alpha=0.7)
# Plot model predictions
ax_progress.plot(frame_indices, predictions, linewidth=2.5, color='#2E86AB',
label='Model Predictions')
ax_progress.fill_between(frame_indices, 0, predictions, alpha=0.2, color='#2E86AB')
# Add reference line at 1.0
ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
# Styling
ax_progress.set_xlabel('Frame Index', fontsize=12)
ax_progress.set_ylabel('Task Progress', fontsize=12)
ax_progress.set_title(f'ReWiND Task Progress Prediction\nTask: "{task_description}"',
fontsize=14, fontweight='bold')
ax_progress.grid(True, alpha=0.3)
ax_progress.set_ylim(-0.05, 1.1)
ax_progress.legend(loc='upper left')
# Compute alignment metrics
mae = np.mean(np.abs(predictions - expected_progress))
rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2))
# Add statistics box
stats_text = (
f'Frames: {len(predictions)}\n'
f'Model Final: {predictions[-1]:.3f}\n'
f'Model Max: {predictions.max():.3f}\n'
f'Model Mean: {predictions.mean():.3f}\n'
f'MAE: {mae:.3f}\n'
f'RMSE: {rmse:.3f}'
)
ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
fontsize=10, verticalalignment='bottom', horizontalalignment='right',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
# Show sample frames if requested
if show_frames:
# Select evenly spaced frames
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
# Create subplot for frames
ax_frames = fig.add_subplot(gs[1])
ax_frames.axis('off')
# Create grid for frames
frame_height = frames[0].shape[0]
frame_width = frames[0].shape[1]
combined_width = frame_width * num_sample_frames
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
for i, frame_idx in enumerate(frame_indices_to_show):
frame = frames[frame_idx]
if frame.shape[-1] == 1:
frame = np.repeat(frame, 3, axis=-1)
# Add frame to combined image
x_start = i * frame_width
x_end = (i + 1) * frame_width
combined_image[:, x_start:x_end] = frame
# Add frame number and progress value
progress_val = predictions[frame_idx]
label = f'Frame {frame_idx}\nProgress: {progress_val:.3f}'
# Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label,
ha='center', va='top', fontsize=8,
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
ax_frames.imshow(combined_image)
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
# Save figure
plt.tight_layout()
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
logger.info(f"Saved visualization to {output_path}")
plt.close()
def main():
args = parse_args()
# Setup device
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
else:
device = args.device
logger.info(f"Using device: {device}")
# Load model
logger.info(f"Loading ReWiND model from {args.model_id}...")
model = ReWiNDRewardModel.from_pretrained(args.model_id)
model.to(device)
model.eval()
logger.info("Model loaded successfully")
# Load dataset
logger.info(f"Loading dataset {args.dataset_repo}...")
dataset = LeRobotDataset(args.dataset_repo)
logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")
# Validate episode index
if args.episode_index >= len(dataset.meta.episodes):
raise ValueError(
f"Episode index {args.episode_index} out of range. "
f"Dataset has {len(dataset.meta.episodes)} episodes."
)
# Determine which image key to use
image_key = args.image_key if args.image_key is not None else model.config.image_key
logger.info(f"Using image key: {image_key}")
# Load episode data (this also extracts the task description from the episode)
frames, start_idx, end_idx, dataset_task = load_episode_data(dataset, args.episode_index, image_key)
# Use task description from dataset if available, otherwise use command-line argument
task_description = dataset_task if dataset_task is not None else args.task_description
logger.info(f"Using task description: '{task_description}'")
# Run inference
predictions, expected_progress = run_inference(model, frames, task_description)
# Create visualization
output_dir = Path(args.output_dir)
output_path = output_dir / f"rewind_prediction_ep{args.episode_index}.png"
visualize_predictions(
frames,
predictions,
expected_progress,
task_description,
output_path,
show_frames=args.show_frames,
num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize)
)
# Save predictions and expected progress as numpy arrays
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npy"
expected_path = output_dir / f"expected_progress_ep{args.episode_index}.npy"
np.save(predictions_path, predictions)
np.save(expected_path, expected_progress)
logger.info(f"Saved predictions array to {predictions_path}")
logger.info(f"Saved expected progress to {expected_path}")
# Compute alignment metrics
mae = np.mean(np.abs(predictions - expected_progress))
rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2))
correlation = np.corrcoef(predictions, expected_progress)[0, 1]
# Print summary
logger.info("\n" + "="*60)
logger.info("INFERENCE SUMMARY")
logger.info("="*60)
logger.info(f"Model: {args.model_id}")
logger.info(f"Dataset: {args.dataset_repo}")
logger.info(f"Episode: {args.episode_index}")
logger.info(f"Task: {task_description}")
logger.info(f"Frames: {len(frames)}")
logger.info(f"\nModel Predictions:")
logger.info(f" Final: {predictions[-1]:.3f}")
logger.info(f" Max: {predictions.max():.3f}")
logger.info(f" Mean: {predictions.mean():.3f}")
logger.info(f" Std: {predictions.std():.3f}")
logger.info(f"\nExpected Progress (Linear):")
logger.info(f" Final: {expected_progress[-1]:.3f}")
logger.info(f" Mean: {expected_progress.mean():.3f}")
logger.info(f"\nAlignment Metrics:")
logger.info(f" MAE: {mae:.3f}")
logger.info(f" RMSE: {rmse:.3f}")
logger.info(f" Correlation: {correlation:.3f}")
logger.info(f"\nOutput:")
logger.info(f" Visualization: {output_path}")
logger.info("="*60)
# Diagnostic warnings
if predictions.std() < 0.05:
logger.warning("\n⚠ WARNING: Mode collapse detected (std < 0.05)")
logger.warning(" Model predictions show very low variance.")
logger.warning(" This indicates the model was likely trained with incorrect")
logger.warning(" progress normalization (absolute indices instead of remaining length).")
elif mae > 0.3:
logger.warning("\n⚠ WARNING: High prediction error (MAE > 0.3)")
logger.warning(" Model predictions deviate significantly from expected linear progress.")
logger.warning(" Consider retraining with correct progress normalization.")
elif correlation < 0.5:
logger.warning("\n⚠ WARNING: Low correlation with expected progress (< 0.5)")
logger.warning(" Model predictions don't align well with linear task progression.")
else:
logger.info("\n✓ Model predictions show healthy progression!")
if __name__ == "__main__":
main()
+1 -1
View File
@@ -244,7 +244,7 @@ def run_inference(
logger.info("Encoding video frames with CLIP...")
video_embeddings = model.encode_images(frames)
logger.info("Encoding task description with MiniLM...")
logger.info("Encoding task description with CLIP...")
text_embedding = model.encode_text(task_description)
# Get config values