Files
lerobot/benchmark_rlearn_vs_rewind.py
T
2025-08-29 15:33:45 +02:00

346 lines
12 KiB
Python

#!/usr/bin/env python
"""
Benchmark script to compare forward pass speed between RLearn and ReWiND implementations.
This script compares the inference speed of:
1. RLearn model (lerobot implementation)
2. ReWiND model (reference implementation)
Both models use the same backbone architectures (DINOv2 + sentence-transformers)
and implement similar reward modeling approaches.
"""
import time
from itertools import chain
from random import random
import einx
import torch
import torch.nn.functional as F
from einops import pack, rearrange, repeat, unpack
from hl_gauss_pytorch import HLGaussLayer
from sentence_transformers import SentenceTransformer
from torch import nn
from torch.nn.utils.rnn import pad_sequence
# Dependencies of the ReWiND reference implementation vendored below
from transformers import AutoImageProcessor, AutoModel
from vit_pytorch.accept_video_wrapper import AcceptVideoWrapper
from x_mlps_pytorch import Feedforwards
from x_transformers import Decoder
from lerobot.constants import OBS_IMAGES, OBS_LANGUAGE
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
# RLearn implementation
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
# ReWiND helper functions
def exists(v):
    """Tell whether *v* holds a value, i.e. is not ``None`` (falsy values still count)."""
    missing = v is None
    return not missing
def satisfy_prob(prob):
    """Draw one uniform sample in [0, 1) and report whether *prob* exceeds it."""
    draw = random()
    return prob > draw
def mask_from_lens(lens):
    """Build a boolean ``(len(lens), lens.max())`` mask, True where position < length.

    Row ``b`` has ``lens[b]`` leading True entries followed by False padding.
    """
    longest = lens.amax().item()
    positions = torch.arange(longest, device=lens.device)
    # Broadcast (1, n) against (b, 1) -> (b, n).
    return positions.unsqueeze(0) < lens.unsqueeze(1)
def randint(min_value: int, max_value: torch.Tensor):
    """Sample one integer per element of *max_value*, in ``[min_value, max_value]``.

    Uniform noise is scaled to the span and rounded. Because of the rounding, the
    two endpoints are half as likely as interior values. Entries with
    ``max_value < min_value`` are clamped up to ``min_value``.
    """
    span = (max_value - min_value).float()
    noise = torch.rand_like(span)
    sample = span * noise + min_value
    rounded = sample.round()
    return rounded.clamp(min=min_value).long()
# ReWiND DinoImageEmbedder
class DinoImageEmbedder(nn.Module):
    """Image encoder for the ReWiND reference: DINOv2 CLS embedding per image.

    Wraps the Hugging Face ``facebook/dinov2-base`` image processor and model.
    """

    def __init__(self):
        super().__init__()
        self.image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        self.image_model = AutoModel.from_pretrained("facebook/dinov2-base")

    def forward(self, images):
        """Embed a batch of images and return the position-0 (CLS) token.

        Args:
            images: Anything ``self.image_processor`` accepts (e.g. a batched
                image tensor) -- exact layout TODO confirm against the caller.

        Returns:
            ``outputs[0][:, 0]`` of the image model, presumably ``(batch, hidden_dim)``.
        """
        model_inputs = self.image_processor(images, return_tensors="pt")
        # The processor returns CPU tensors (it works through NumPy). Once the
        # module has been moved with `.to("cuda")`, the weights live on the GPU,
        # so the inputs must follow them or the forward pass raises a device
        # mismatch.
        weight = next(self.image_model.parameters(), None)
        if weight is not None:
            model_inputs = {
                key: value.to(weight.device) if torch.is_tensor(value) else value
                for key, value in model_inputs.items()
            }
        outputs = self.image_model(**model_inputs)
        last_hidden_states = outputs[0]
        return last_hidden_states[:, 0]  # cls
# ReWiND RewardModel
class RewardModel(nn.Module):
    """ReWiND reference reward model: per-frame predictions from video + language.

    The decoder sees the packed token sequence
    ``[language | register | extra | video]``. Only the attended video
    positions are projected to per-frame outputs.

    Args:
        decoder: A ready-built ``x_transformers.Decoder`` or a dict of its
            constructor kwargs. ``None`` means
            ``dict(dim=768, depth=4, heads=8, attn_dim_head=64)``.
        image_model: Per-frame image encoder; defaults to ``DinoImageEmbedder()``.
        mlp_predictor_depth: Depth of the ``Feedforwards`` prediction head.
        reward_bins: Head output width when ``categorical_rewards`` is True.
        max_video_frames: NOTE(review): currently unused. ``video_embed`` is
            built without a temporal position embedding; confirm this is intended.
        dim_image_embed: Width of the image-encoder output fed to ``to_video_tokens``.
        num_register_tokens: Number of learned register tokens inserted between
            the language and video tokens.
        lang_per_token_embed: If True, use MiniLM token embeddings with a padding
            mask; otherwise one sentence embedding per command.
        sentence_transformer_path: Model id passed to ``SentenceTransformer``.
        categorical_rewards: If True, the head emits ``reward_bins`` logits per
            frame and the loss is cross-entropy; otherwise ``HLGaussLayer`` is used.
        use_hl_gauss_loss: ``HLGaussLayer`` is built with ``use_regression=not use_hl_gauss_loss``.
        reward_min_value, reward_max_value, reward_hl_gauss_loss_num_bins:
            Support range and bin count handed to the HL-Gauss loss.
    """

    def __init__(
        self,
        decoder: dict | Decoder | None = None,
        image_model: nn.Module | None = None,
        mlp_predictor_depth=3,
        reward_bins=10,
        max_video_frames=16,
        dim_image_embed=768,
        num_register_tokens=4,
        lang_per_token_embed=True,
        sentence_transformer_path="sentence-transformers/all-MiniLM-L12-v2",
        categorical_rewards=False,
        use_hl_gauss_loss=True,
        reward_min_value=0.0,
        reward_max_value=1.0,
        reward_hl_gauss_loss_num_bins=20,
    ):
        super().__init__()
        self.lang_per_token_embed = lang_per_token_embed

        # Text encoder; encoding a probe string reveals its embedding width.
        self.mini_lm = SentenceTransformer(sentence_transformer_path)
        mini_lm_dim = self.mini_lm.encode(["__"]).shape[-1]

        if not exists(image_model):
            image_model = DinoImageEmbedder()
        self.video_embed = AcceptVideoWrapper(image_model)

        # Use a fresh dict rather than a shared mutable default. Also accept a
        # pre-built Decoder, as the annotation promises: `Decoder(**decoder)` on
        # an instance would raise a TypeError.
        if decoder is None:
            decoder = dict(dim=768, depth=4, heads=8, attn_dim_head=64)
        self.decoder = decoder if isinstance(decoder, Decoder) else Decoder(**decoder)
        dim = self.decoder.dim

        # Learned offset added only to the first video token (marks where the video starts).
        self.first_pos_emb = nn.Parameter(torch.randn(dim) * 1e-2)
        self.to_lang_tokens = nn.Linear(mini_lm_dim, dim)
        self.to_video_tokens = nn.Linear(dim_image_embed, dim)
        self.mlp_predictor = Feedforwards(
            dim=dim, dim_out=reward_bins if categorical_rewards else None, depth=mlp_predictor_depth
        )
        self.num_register_tokens = num_register_tokens
        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
        self.categorical_rewards = categorical_rewards
        self.hl_gauss_layer = HLGaussLayer(
            dim=dim,
            use_regression=not use_hl_gauss_loss,
            hl_gauss_loss=dict(
                min_value=reward_min_value,
                max_value=reward_max_value,
                num_bins=reward_hl_gauss_loss_num_bins,
            ),
        )

    def parameters(self, recurse: bool = True):
        """Return the parameters meant to be optimized.

        The MiniLM text encoder and the image backbone wrapped by ``video_embed``
        are deliberately left out (presumably kept frozen -- confirm). The learned
        ``first_pos_emb`` and ``register_tokens`` are included. ``recurse`` mirrors
        ``nn.Module.parameters``: ``recurse=False`` yields only the parameters
        registered directly on this module.
        """
        if not recurse:
            return super().parameters(recurse=False)
        # `pos_emb` is presumably only created when the video wrapper is asked for
        # a temporal embedding (not the case above), so look it up defensively.
        time_pos_emb = getattr(self.video_embed, "pos_emb", None)
        return chain(
            self.decoder.parameters(),
            (self.first_pos_emb, self.register_tokens),
            () if time_pos_emb is None else (time_pos_emb,),
            self.to_lang_tokens.parameters(),
            self.to_video_tokens.parameters(),
            self.mlp_predictor.parameters(),
            self.hl_gauss_layer.parameters(),
        )

    def _embed_commands(self, commands, device):
        """Encode *commands* with MiniLM; return (padded embeddings, key mask or None).

        The mask is only built for per-token embeddings. It is True on real
        tokens and False on the padding added by ``pad_sequence``.
        """
        lang_embeds = self.mini_lm.encode(
            commands,
            output_value="token_embeddings" if self.lang_per_token_embed else "sentence_embedding",
            convert_to_numpy=False,
        )
        mask = None
        if self.lang_per_token_embed:
            # Read each command's token count BEFORE padding. After pad_sequence
            # every row has the padded length, which would mark padding as valid.
            lens = torch.tensor([t.shape[0] for t in lang_embeds], device=device)
            mask = mask_from_lens(lens)
        lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
        return lang_embeds, mask

    def forward(
        self,
        commands: list[str],
        video,  # (b c t h w)
        extra_embed_tokens=None,  # (b n d)
        rewards=None,
        video_lens=None,
    ):
        """Predict per-frame rewards, or compute the training loss when *rewards* is given.

        Returns:
            Without *rewards*: the head's per-frame logits if ``categorical_rewards``,
            otherwise ``hl_gauss_layer`` applied to the per-frame embeddings.
            With *rewards*: a cross-entropy loss (categorical, padding ignored via -1)
            or the ``hl_gauss_layer`` loss masked by *video_lens*.
        """
        batch = video.shape[0]
        assert len(commands) == batch
        device = video.device

        register_tokens = repeat(self.register_tokens, "n d -> b n d", b=batch)

        lang_embeds, mask = self._embed_commands(commands, device)

        # Extra tokens: an empty slice keeps `pack` uniform when none are given.
        if not exists(extra_embed_tokens):
            extra_embed_tokens = register_tokens[:, 0:0]
        elif exists(mask):
            mask = F.pad(mask, (0, extra_embed_tokens.shape[-2]), value=True)

        video_embeds = self.video_embed(video, eval_with_no_grad=True)

        # Register and video tokens are always attended to. The mask's total
        # length matches the packed sequence, since every non-language entry is True.
        if exists(mask):
            mask = F.pad(mask, (0, video_embeds.shape[1] + self.num_register_tokens), value=True)

        lang_tokens = self.to_lang_tokens(lang_embeds)
        video_tokens = self.to_video_tokens(video_embeds)

        first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
        first_video_token = first_video_token + repeat(self.first_pos_emb, "d -> b 1 d", b=batch)
        video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)

        tokens, packed_shape = pack((lang_tokens, register_tokens, extra_embed_tokens, video_tokens), "b * d")
        attended = self.decoder(tokens, mask=mask)

        *_, attended_video_tokens = unpack(attended, packed_shape, "b * d")
        video_frame_embed_or_logits = self.mlp_predictor(attended_video_tokens)

        # Trim to the longest real video and mark padded frames in the targets.
        video_mask = None
        if exists(video_lens):
            video_mask = mask_from_lens(video_lens)
            max_video_len = video_lens.amax().item()
            video_frame_embed_or_logits = video_frame_embed_or_logits[:, :max_video_len]
            if exists(rewards):
                rewards = rewards[:, :max_video_len]
                rewards = einx.where("b t, b t,", video_mask, rewards, -1)

        if not exists(rewards):
            if self.categorical_rewards:
                return video_frame_embed_or_logits
            return self.hl_gauss_layer(video_frame_embed_or_logits)

        if self.categorical_rewards:
            assert rewards.dtype in (torch.long, torch.int)
            return F.cross_entropy(
                rearrange(video_frame_embed_or_logits, "b t l -> b l t"), rewards, ignore_index=-1
            )

        assert rewards.dtype == torch.float
        return self.hl_gauss_layer(video_frame_embed_or_logits, rewards, mask=video_mask)
def _sync_if_cuda(device):
    """Wait for all queued CUDA work on *device* so wall-clock timings are complete."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _time_calls(fn, num_runs, device):
    """Call ``fn()`` ``num_runs`` times without autograd; return per-call durations in seconds."""
    durations = []
    with torch.no_grad():
        for _ in range(num_runs):
            start = time.perf_counter()
            fn()
            # CUDA launches are asynchronous: sync before stopping the clock.
            _sync_if_cuda(device)
            durations.append(time.perf_counter() - start)
    return durations


def _summarize_ms(durations):
    """Return (mean, sample std, min, max) of *durations* (seconds) in milliseconds."""
    ms = torch.tensor(durations, dtype=torch.float64) * 1000
    return ms.mean().item(), ms.std().item(), ms.min().item(), ms.max().item()


def benchmark_models(*, num_runs: int = 100, num_warmup: int = 3):
    """Benchmark forward pass speed of RLearn vs ReWiND models.

    Both models get the same random video (RLearn as ``(B, T, C, H, W)``, ReWiND
    as ``(B, C, T, H, W)``) and the same two commands. The function prints latency
    statistics, the speedup of RLearn over ReWiND, and the two output shapes.

    Args:
        num_runs: Timed forward passes per model (must be >= 1).
        num_warmup: Untimed forward passes per model before timing.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    print("Setting up models and test data...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    batch_size = 2
    num_frames = 16
    height, width = 224, 224
    commands = [
        "pick up the blue ball and put it in the red tray",
        "pick up the red cube and put it in the green bin",
    ]
    video_rewind = torch.rand(batch_size, 3, num_frames, height, width, device=device)  # (B, C, T, H, W)
    video_rlearn = video_rewind.permute(0, 2, 1, 3, 4)  # (B, T, C, H, W)
    batch = {OBS_IMAGES: video_rlearn, OBS_LANGUAGE: commands}

    print("Initializing RLearn model...")
    rlearn_model = RLearNPolicy(RLearNConfig()).to(device)
    rlearn_model.eval()

    print("Initializing ReWiND model...")
    rewind_model = RewardModel().to(device)
    rewind_model.eval()

    def run_rlearn():
        return rlearn_model.predict_rewards(batch)

    def run_rewind():
        return rewind_model(commands, video_rewind)

    print("Warming up models...")
    with torch.no_grad():
        for _ in range(num_warmup):
            run_rlearn()
            run_rewind()
    # Drain warm-up work so it is not billed to the first timed RLearn call.
    _sync_if_cuda(device)

    print("\nBenchmarking RLearn model...")
    rlearn_times = _time_calls(run_rlearn, num_runs, device)
    print("Benchmarking ReWiND model...")
    rewind_times = _time_calls(run_rewind, num_runs, device)

    rlearn_avg, rlearn_std, rlearn_min, rlearn_max = _summarize_ms(rlearn_times)
    rewind_avg, rewind_std, rewind_min, rewind_max = _summarize_ms(rewind_times)

    print("\n" + "=" * 60)
    print(f"BENCHMARK RESULTS ({num_runs} runs, inference only)")
    print("=" * 60)
    print(f"RLearn avg: {rlearn_avg:.2f} ms")
    print(f"RLearn std: {rlearn_std:.2f} ms")
    print(f"RLearn min: {rlearn_min:.2f} ms")
    print(f"RLearn max: {rlearn_max:.2f} ms")
    print(f"ReWiND avg: {rewind_avg:.2f} ms")
    print(f"ReWiND std: {rewind_std:.2f} ms")
    print(f"ReWiND min: {rewind_min:.2f} ms")
    print(f"ReWiND max: {rewind_max:.2f} ms")

    # Ratio of mean latencies; > 1 means RLearn needs less time per forward pass.
    # (The previous rlearn/rewind ratio was inverted relative to its verdict.)
    speedup = rewind_avg / rlearn_avg if rlearn_avg > 0 else float("inf")
    print(f"Speedup of RLearn over ReWiND (ReWiND time / RLearn time): {speedup:.2f}x")
    print("RLearn is faster!" if speedup > 1 else "ReWiND is faster!")

    print("\nOutput shapes:")
    with torch.no_grad():
        rlearn_output = run_rlearn()
        rewind_output = run_rewind()
    print(f"RLearn output shape: {rlearn_output.shape}")
    print(f"ReWiND output shape: {rewind_output.shape}")
    if rlearn_output.shape == rewind_output.shape:
        print("✓ Output shapes match")
    else:
        print("⚠ Output shapes differ - this may indicate implementation differences")
    print("\nBenchmark completed successfully!")
if __name__ == "__main__":
    # Script entry point; importing this module only defines things.
    benchmark_models()