mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
add benchmark
This commit is contained in:
@@ -0,0 +1,345 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Benchmark script to compare forward pass speed between RLearn and ReWiND implementations.
|
||||||
|
|
||||||
|
This script compares the inference speed of:
|
||||||
|
1. RLearn model (lerobot implementation)
|
||||||
|
2. ReWiND model (reference implementation)
|
||||||
|
|
||||||
|
Both models use the same backbone architectures (DINOv2 + sentence-transformers)
|
||||||
|
and implement similar reward modeling approaches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from itertools import chain
|
||||||
|
from random import random
|
||||||
|
|
||||||
|
import einx
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import pack, rearrange, repeat, unpack
|
||||||
|
from hl_gauss_pytorch import HLGaussLayer
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
# ReWiND implementation (copied from user's context)
|
||||||
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
from vit_pytorch.accept_video_wrapper import AcceptVideoWrapper
|
||||||
|
from x_mlps_pytorch import Feedforwards
|
||||||
|
from x_transformers import Decoder
|
||||||
|
|
||||||
|
from lerobot.constants import OBS_IMAGES, OBS_LANGUAGE
|
||||||
|
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||||
|
|
||||||
|
# RLearn implementation
|
||||||
|
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
|
||||||
|
|
||||||
|
|
||||||
|
# ReWiND helper functions
|
||||||
|
def exists(v):
|
||||||
|
return v is not None
|
||||||
|
|
||||||
|
|
||||||
|
def satisfy_prob(prob):
|
||||||
|
return random() < prob
|
||||||
|
|
||||||
|
|
||||||
|
def mask_from_lens(lens):
|
||||||
|
seq = torch.arange(lens.amax().item(), device=lens.device)
|
||||||
|
mask = einx.less("n, b -> b n", seq, lens)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def randint(min_value: int, max_value: torch.Tensor):
|
||||||
|
value_range = (max_value - min_value).float()
|
||||||
|
return ((value_range * torch.rand_like(value_range)) + min_value).round().clamp(min=min_value).long()
|
||||||
|
|
||||||
|
|
||||||
|
# ReWiND DinoImageEmbedder
|
||||||
|
class DinoImageEmbedder(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
||||||
|
self.image_model = AutoModel.from_pretrained("facebook/dinov2-base")
|
||||||
|
|
||||||
|
def forward(self, images):
|
||||||
|
model_inputs = self.image_processor(images, return_tensors="pt")
|
||||||
|
outputs = self.image_model(**model_inputs)
|
||||||
|
last_hidden_states = outputs[0]
|
||||||
|
return last_hidden_states[:, 0] # cls
|
||||||
|
|
||||||
|
|
||||||
|
# ReWiND RewardModel
|
||||||
|
class RewardModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
decoder: dict | Decoder = dict(dim=768, depth=4, heads=8, attn_dim_head=64),
|
||||||
|
image_model: nn.Module | None = None,
|
||||||
|
mlp_predictor_depth=3,
|
||||||
|
reward_bins=10,
|
||||||
|
max_video_frames=16,
|
||||||
|
dim_image_embed=768,
|
||||||
|
num_register_tokens=4,
|
||||||
|
lang_per_token_embed=True,
|
||||||
|
sentence_transformer_path="sentence-transformers/all-MiniLM-L12-v2",
|
||||||
|
categorical_rewards=False,
|
||||||
|
use_hl_gauss_loss=True,
|
||||||
|
reward_min_value=0.0,
|
||||||
|
reward_max_value=1.0,
|
||||||
|
reward_hl_gauss_loss_num_bins=20,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.lang_per_token_embed = lang_per_token_embed
|
||||||
|
self.mini_lm = SentenceTransformer(sentence_transformer_path)
|
||||||
|
mini_lm_dim = self.mini_lm.encode(["__"]).shape[-1]
|
||||||
|
|
||||||
|
if not exists(image_model):
|
||||||
|
image_model = DinoImageEmbedder()
|
||||||
|
|
||||||
|
self.video_embed = AcceptVideoWrapper(image_model)
|
||||||
|
self.decoder = Decoder(**decoder)
|
||||||
|
dim = self.decoder.dim
|
||||||
|
|
||||||
|
self.first_pos_emb = nn.Parameter(torch.randn(dim) * 1e-2)
|
||||||
|
self.to_lang_tokens = nn.Linear(mini_lm_dim, dim)
|
||||||
|
self.to_video_tokens = nn.Linear(dim_image_embed, dim)
|
||||||
|
|
||||||
|
self.mlp_predictor = Feedforwards(
|
||||||
|
dim=dim, dim_out=reward_bins if categorical_rewards else None, depth=mlp_predictor_depth
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_register_tokens = num_register_tokens
|
||||||
|
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
||||||
|
self.categorical_rewards = categorical_rewards
|
||||||
|
|
||||||
|
self.hl_gauss_layer = HLGaussLayer(
|
||||||
|
dim=dim,
|
||||||
|
use_regression=not use_hl_gauss_loss,
|
||||||
|
hl_gauss_loss=dict(
|
||||||
|
min_value=reward_min_value,
|
||||||
|
max_value=reward_max_value,
|
||||||
|
num_bins=reward_hl_gauss_loss_num_bins,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def parameters(self):
|
||||||
|
return chain(
|
||||||
|
self.decoder.parameters(),
|
||||||
|
iter((self.video_embed.pos_emb,)),
|
||||||
|
self.to_lang_tokens.parameters(),
|
||||||
|
self.to_video_tokens.parameters(),
|
||||||
|
self.mlp_predictor.parameters(),
|
||||||
|
self.hl_gauss_layer.parameters(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
commands: list[str],
|
||||||
|
video, # (b c t h w)
|
||||||
|
extra_embed_tokens=None, # (b n d)
|
||||||
|
rewards=None,
|
||||||
|
video_lens=None,
|
||||||
|
):
|
||||||
|
batch = video.shape[0]
|
||||||
|
assert len(commands) == batch
|
||||||
|
|
||||||
|
device = video.device
|
||||||
|
mask = None
|
||||||
|
|
||||||
|
# register tokens
|
||||||
|
register_tokens = repeat(self.register_tokens, "n d -> b n d", b=batch)
|
||||||
|
|
||||||
|
# language embed
|
||||||
|
lang_embeds = self.mini_lm.encode(
|
||||||
|
commands,
|
||||||
|
output_value="token_embeddings" if self.lang_per_token_embed else "sentence_embedding",
|
||||||
|
convert_to_numpy=False,
|
||||||
|
)
|
||||||
|
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
|
||||||
|
|
||||||
|
if self.lang_per_token_embed:
|
||||||
|
lens = torch.tensor([t.shape[0] for t in lang_embeds], device=device)
|
||||||
|
mask = mask_from_lens(lens)
|
||||||
|
|
||||||
|
# extra embeds
|
||||||
|
if not exists(extra_embed_tokens):
|
||||||
|
extra_embed_tokens = register_tokens[:, 0:0]
|
||||||
|
|
||||||
|
elif exists(extra_embed_tokens) and exists(mask):
|
||||||
|
mask = F.pad(mask, (0, extra_embed_tokens.shape[-2]), value=True)
|
||||||
|
|
||||||
|
# video embeds
|
||||||
|
video_embeds = self.video_embed(video, eval_with_no_grad=True)
|
||||||
|
|
||||||
|
if self.lang_per_token_embed:
|
||||||
|
mask = F.pad(mask, (0, video_embeds.shape[1] + self.num_register_tokens), value=True)
|
||||||
|
|
||||||
|
# linear projections
|
||||||
|
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||||
|
video_tokens = self.to_video_tokens(video_embeds)
|
||||||
|
|
||||||
|
# add video start positional embedding
|
||||||
|
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
|
||||||
|
first_video_token = first_video_token + repeat(self.first_pos_emb, "d -> b 1 d", b=batch)
|
||||||
|
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
|
||||||
|
|
||||||
|
# pack all tokens for attention
|
||||||
|
tokens, lang_video_packed_shape = pack(
|
||||||
|
(lang_tokens, register_tokens, extra_embed_tokens, video_tokens), "b * d"
|
||||||
|
)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
attended = self.decoder(tokens, mask=mask)
|
||||||
|
|
||||||
|
# unpack and project the video tokens to logits to train reward predictor
|
||||||
|
_, _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, "b * d")
|
||||||
|
|
||||||
|
video_frame_embed_or_logits = self.mlp_predictor(attended_video_tokens)
|
||||||
|
|
||||||
|
# determine video masking for loss
|
||||||
|
video_mask = None
|
||||||
|
if exists(video_lens):
|
||||||
|
video_mask = mask_from_lens(video_lens)
|
||||||
|
max_video_len = video_lens.amax().item()
|
||||||
|
video_frame_embed_or_logits = video_frame_embed_or_logits[:, :max_video_len]
|
||||||
|
if exists(rewards):
|
||||||
|
rewards = rewards[:, :max_video_len]
|
||||||
|
rewards = einx.where("b t, b t,", video_mask, rewards, -1)
|
||||||
|
|
||||||
|
# return raw prediction or loss
|
||||||
|
return_loss = exists(rewards)
|
||||||
|
if not return_loss:
|
||||||
|
if self.categorical_rewards:
|
||||||
|
return video_frame_embed_or_logits
|
||||||
|
else:
|
||||||
|
return self.hl_gauss_layer(video_frame_embed_or_logits)
|
||||||
|
|
||||||
|
# calculate loss
|
||||||
|
if self.categorical_rewards:
|
||||||
|
assert rewards.dtype in (torch.long, torch.int)
|
||||||
|
loss = F.cross_entropy(
|
||||||
|
rearrange(video_frame_embed_or_logits, "b t l -> b l t"), rewards, ignore_index=-1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert rewards.dtype == torch.float
|
||||||
|
loss = self.hl_gauss_layer(video_frame_embed_or_logits, rewards, mask=video_mask)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_models():
|
||||||
|
"""Benchmark forward pass speed of RLearn vs ReWiND models."""
|
||||||
|
|
||||||
|
print("Setting up models and test data...")
|
||||||
|
|
||||||
|
# Set device
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
batch_size = 2
|
||||||
|
num_frames = 16
|
||||||
|
height, width = 224, 224
|
||||||
|
|
||||||
|
commands = [
|
||||||
|
"pick up the blue ball and put it in the red tray",
|
||||||
|
"pick up the red cube and put it in the green bin",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create video tensor (B, C, T, H, W) for ReWiND
|
||||||
|
video_rewind = torch.rand(batch_size, 3, num_frames, height, width, device=device)
|
||||||
|
|
||||||
|
# Create video tensor (B, T, C, H, W) for RLearn
|
||||||
|
video_rlearn = video_rewind.permute(0, 2, 1, 3, 4) # (B, T, C, H, W)
|
||||||
|
|
||||||
|
# Create batch dict for RLearn
|
||||||
|
batch = {OBS_IMAGES: video_rlearn, OBS_LANGUAGE: commands}
|
||||||
|
|
||||||
|
# Initialize RLearn model
|
||||||
|
print("Initializing RLearn model...")
|
||||||
|
rlearn_config = RLearNConfig()
|
||||||
|
rlearn_model = RLearNPolicy(rlearn_config).to(device)
|
||||||
|
rlearn_model.eval()
|
||||||
|
|
||||||
|
# Initialize ReWiND model
|
||||||
|
print("Initializing ReWiND model...")
|
||||||
|
rewind_model = RewardModel().to(device)
|
||||||
|
rewind_model.eval()
|
||||||
|
|
||||||
|
# Warm up both models
|
||||||
|
print("Warming up models...")
|
||||||
|
with torch.no_grad():
|
||||||
|
for _ in range(3):
|
||||||
|
_ = rlearn_model.predict_rewards(batch)
|
||||||
|
_ = rewind_model(commands, video_rewind)
|
||||||
|
|
||||||
|
# Benchmark RLearn
|
||||||
|
print("\nBenchmarking RLearn model...")
|
||||||
|
rlearn_times = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(100):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
rewards = rlearn_model.predict_rewards(batch)
|
||||||
|
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
rlearn_times.append(end_time - start_time)
|
||||||
|
|
||||||
|
# Benchmark ReWiND
|
||||||
|
print("Benchmarking ReWiND model...")
|
||||||
|
rewind_times = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(100):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
rewards = rewind_model(commands, video_rewind)
|
||||||
|
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
rewind_times.append(end_time - start_time)
|
||||||
|
|
||||||
|
# Calculate statistics
|
||||||
|
rlearn_avg = sum(rlearn_times) / len(rlearn_times) * 1000 # Convert to ms
|
||||||
|
rlearn_std = torch.tensor(rlearn_times).std().item() * 1000
|
||||||
|
rlearn_min = min(rlearn_times) * 1000
|
||||||
|
rlearn_max = max(rlearn_times) * 1000
|
||||||
|
|
||||||
|
rewind_avg = sum(rewind_times) / len(rewind_times) * 1000
|
||||||
|
rewind_std = torch.tensor(rewind_times).std().item() * 1000
|
||||||
|
rewind_min = min(rewind_times) * 1000
|
||||||
|
rewind_max = max(rewind_times) * 1000
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("BENCHMARK RESULTS (100 runs, inference only)")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"RLearn avg: {rlearn_avg:.2f} ms")
|
||||||
|
print(f"RLearn std: {rlearn_std:.2f} ms")
|
||||||
|
print(f"RLearn min: {rlearn_min:.2f} ms")
|
||||||
|
print(f"RLearn max: {rlearn_max:.2f} ms")
|
||||||
|
print(f"ReWiND avg: {rewind_avg:.2f} ms")
|
||||||
|
print(f"ReWiND std: {rewind_std:.2f} ms")
|
||||||
|
print(f"ReWiND min: {rewind_min:.2f} ms")
|
||||||
|
print(f"ReWiND max: {rewind_max:.2f} ms")
|
||||||
|
|
||||||
|
speedup = rlearn_avg / rewind_avg if rewind_avg > 0 else float("inf")
|
||||||
|
print(f"Speedup (RLearn/ReWiND): {speedup:.2f}x")
|
||||||
|
print(f"{'RLearn is faster!' if speedup > 1 else 'ReWiND is faster!'}")
|
||||||
|
# Verify outputs are similar in shape
|
||||||
|
print("\nOutput shapes:")
|
||||||
|
with torch.no_grad():
|
||||||
|
rlearn_output = rlearn_model.predict_rewards(batch)
|
||||||
|
rewind_output = rewind_model(commands, video_rewind)
|
||||||
|
|
||||||
|
print(f"RLearn output shape: {rlearn_output.shape}")
|
||||||
|
print(f"ReWiND output shape: {rewind_output.shape}")
|
||||||
|
|
||||||
|
if rlearn_output.shape == rewind_output.shape:
|
||||||
|
print("✓ Output shapes match")
|
||||||
|
else:
|
||||||
|
print("⚠ Output shapes differ - this may indicate implementation differences")
|
||||||
|
|
||||||
|
print("\nBenchmark completed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
benchmark_models()
|
||||||
File diff suppressed because one or more lines are too long
@@ -32,8 +32,7 @@ Input should be the current image or whole video and the task goal specified in
|
|||||||
Archiutecture:
|
Archiutecture:
|
||||||
_ inputs: video o1:T (or current o1:t), language z;
|
_ inputs: video o1:T (or current o1:t), language z;
|
||||||
_ DINO v3 ViT-B/16 (86M params): https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m for vision encoding
|
_ DINO v3 ViT-B/16 (86M params): https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m for vision encoding
|
||||||
_ sentence-transformers/all-MiniLM-L12-v2: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 for text encoding
|
\_ sentence-transformers/all-MiniLM-L12-v2: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 for text encoding \* Temporal module: small causal transformer ("cross-modal sequential aggregator"), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
|
||||||
\* Temporal module: small causal transformer ("cross-modal sequential aggregator"), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
|
|
||||||
|
|
||||||
Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01
|
Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01
|
||||||
|
|
||||||
@@ -56,11 +55,13 @@ _ Epic-Kitchens-100
|
|||||||
_ Something-Something v. 2 Dataset https://www.qualcomm.com/developer/software/something-something-v-2-dataset
|
_ Something-Something v. 2 Dataset https://www.qualcomm.com/developer/software/something-something-v-2-dataset
|
||||||
_ Ego4D (3000 hours)
|
_ Ego4D (3000 hours)
|
||||||
_ Open X-Embodiment (OXE)
|
_ Open X-Embodiment (OXE)
|
||||||
_ Age bot world: https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha
|
\_ Agi bot world: https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha
|
||||||
_ GTEA+ Gaze: https://cbs.ic.gatech.edu/fpv/
|
|
||||||
_ YouCook2 dataset
|
|
||||||
_ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
|
||||||
|
|
||||||
|
- GalexiAI dataset: https://opengalaxea.github.io/G0/
|
||||||
|
_ GTEA+ Gaze: https://cbs.ic.gatech.edu/fpv/
|
||||||
|
_ YouCook2 dataset
|
||||||
|
\_ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
||||||
|
- Genie generated dataset?
|
||||||
|
|
||||||
### TODOs:
|
### TODOs:
|
||||||
|
|
||||||
@@ -77,11 +78,10 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
|||||||
- Only rewind loss [x]
|
- Only rewind loss [x]
|
||||||
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
|
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
|
||||||
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
||||||
- benchmark lucidrains vs this implementation forward pass []
|
- Test rewind (evaluate) [x]
|
||||||
- Test rewind (evaluate) []
|
- Cleanup code? []
|
||||||
- Cleanup code? []
|
- benchmark lucidrains vs this implementation forward pass, debug speed []
|
||||||
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
||||||
-----------------
|
|
||||||
- Then on 10 percent
|
- Then on 10 percent
|
||||||
- Ablation dino v2 vs dino v3 base 86 M
|
- Ablation dino v2 vs dino v3 base 86 M
|
||||||
- Add more artificial text to dataset generated by vlm (google gemini) []
|
- Add more artificial text to dataset generated by vlm (google gemini) []
|
||||||
@@ -90,4 +90,5 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
|||||||
- How can we improve spatial aware learning? solve issue of Contrastive learning and position
|
- How can we improve spatial aware learning? solve issue of Contrastive learning and position
|
||||||
- Extend evaluation []
|
- Extend evaluation []
|
||||||
- Add other datasets from OXE metioned in rewind []
|
- Add other datasets from OXE metioned in rewind []
|
||||||
- Ablation for size vision encoder, language encoder, temporal head
|
- Ablation for size vision encoder, language encoder, temporal head []
|
||||||
|
- Add other datasets metnioned here []
|
||||||
|
|||||||
Reference in New Issue
Block a user