siglip again

This commit is contained in:
Pepijn
2025-09-01 10:55:12 +02:00
parent 9dcb407ba7
commit ce5b27d255
+42 -39
View File
@@ -22,17 +22,6 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn
# ReWiND dependencies
try:
from x_transformers import Decoder
import einx
from einops import rearrange, repeat, pack, unpack
except ImportError as e:
raise ImportError(
"ReWiND dependencies not installed. Please install: "
"pip install x-transformers einx einops x-mlps-pytorch"
) from e
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
@@ -95,17 +84,17 @@ class RLearNPolicy(PreTrainedPolicy):
self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model))
# Cross-modal sequential aggregator causal transformer over
# [language tokens | video frame tokens]
self.decoder = Decoder(
dim=config.dim_model,
depth=config.num_layers,
heads=config.num_heads,
ff_mult=config.ff_mult,
attn_dropout=config.dropout,
ff_dropout=config.dropout,
cross_attend=False,
causal=True,
# [language tokens | video frame tokens] using PyTorch TransformerEncoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=config.dim_model,
nhead=config.num_heads,
dim_feedforward=config.dim_model * config.ff_mult,
dropout=config.dropout,
activation="gelu",
batch_first=True,
norm_first=True,
)
self.aggregator = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
# Per-frame predictor head
self.frame_mlp = nn.Sequential(
@@ -138,7 +127,7 @@ class RLearNPolicy(PreTrainedPolicy):
try:
self.vision_model = torch.compile(self.vision_model, mode="reduce-overhead")
self.text_model = torch.compile(self.text_model, mode="reduce-overhead")
self.decoder = torch.compile(self.decoder, mode="reduce-overhead")
self.aggregator = torch.compile(self.aggregator, mode="reduce-overhead")
print("✅ Applied torch.compile to encoders and transformer")
except Exception as e:
print(f"⚠️ torch.compile failed: {e}")
@@ -182,7 +171,7 @@ class RLearNPolicy(PreTrainedPolicy):
(B, T, P, D_vision) where P is number of patch tokens per frame (excludes CLS)
"""
B, T, C, H, W = frames.shape
flat = rearrange(frames, 'b t c h w -> (b t) c h w')
flat = frames.reshape(B * T, C, H, W)
# Optimized: Process tensor directly without numpy conversion
device = next(self.vision_model.parameters()).device
@@ -222,8 +211,9 @@ class RLearNPolicy(PreTrainedPolicy):
# Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean)
try:
cand1 = rearrange(patch_tokens_flat, '(b t) p d -> b t p d', b=B, t=T)
cand2 = rearrange(patch_tokens_flat, '(t b) p d -> b t p d', t=T, b=B)
P = patch_tokens_flat.shape[1]
cand1 = patch_tokens_flat.reshape(B, T, P, -1)
cand2 = patch_tokens_flat.reshape(T, B, P, -1).permute(1, 0, 2, 3)
def mean_time_diff_4d(x):
if T <= 1:
return torch.tensor(0.0, device=x.device)
@@ -237,7 +227,8 @@ class RLearNPolicy(PreTrainedPolicy):
print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
except Exception:
# Fallback to default
patch_features = rearrange(patch_tokens_flat, '(b t) p d -> b t p d', b=B, t=T)
P = patch_tokens_flat.shape[1]
patch_features = patch_tokens_flat.view(B, T, P, -1)
# DEBUG: Analyze vision feature variability (use per-frame pooled features for readability)
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
@@ -297,7 +288,7 @@ class RLearNPolicy(PreTrainedPolicy):
if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2:
# Recover CLS tokens
cls_flat = tokens[:, 0, :] # (BT, D)
cls = rearrange(cls_flat, '(b t) d -> b t d', b=B, t=T)
cls = cls_flat.view(B, T, -1)
b0 = 0
f0, f1 = 0, T - 1
# L2 between CLS at two frames
@@ -421,16 +412,26 @@ class RLearNPolicy(PreTrainedPolicy):
# First-frame positional embedding only
video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
# Build attention mask for decoder (True = keep)
# Language mask from tokenizer, rest are fully valid
full_mask = F.pad(mask, (0, video_tokens.shape[1]), value=True)
# Pack and run transformer
# Build masks for TransformerEncoder
lang_valid = mask # (B, L) True where valid
video_valid = torch.ones(B, video_tokens.shape[1], device=device, dtype=torch.bool)
valid_mask = torch.cat([lang_valid, video_valid], dim=1) # (B, S)
key_padding_mask = ~valid_mask # True -> masked
tokens_seq = torch.cat([lang_tokens, video_tokens], dim=1) # (B, S, D)
# Causal mask (S, S): True masks out future positions
S = tokens_seq.shape[1]
causal_mask = torch.triu(torch.ones(S, S, device=device, dtype=torch.bool), diagonal=1)
transformer_start = time.perf_counter()
tokens_packed, packed_shape = pack((lang_tokens, video_tokens), 'b * d')
attended = self.decoder(tokens_packed, mask=full_mask)
attended_lang, attended_video = unpack(attended, packed_shape, 'b * d')
attended_all = self.aggregator(tokens_seq, src_key_padding_mask=key_padding_mask, mask=causal_mask)
transformer_time = time.perf_counter() - transformer_start
# Split back video part
L_len = lang_tokens.shape[1]
attended_video = attended_all[:, L_len:, :]
# Per-frame prediction
frame_tokens = self.frame_mlp(attended_video) # (B, T_eff, D)
raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff)
@@ -496,11 +497,13 @@ class RLearNPolicy(PreTrainedPolicy):
lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
# Pack and forward
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, video_tokens), 'b * d')
mask_mm = F.pad(mask_mm, (0, video_tokens.shape[1]), value=True)
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
_, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
# Pack and forward with masks
lang_valid_mm = mask_mm
valid_mask_mm = torch.cat([lang_valid_mm, video_valid], dim=1)
key_padding_mask_mm = ~valid_mask_mm
tokens_seq_mm = torch.cat([lang_tokens_mm, video_tokens], dim=1)
attended_all_mm = self.aggregator(tokens_seq_mm, src_key_padding_mask=key_padding_mask_mm, mask=causal_mask)
attended_video_mm = attended_all_mm[:, L_len:, :]
# Process mismatch frames with single MLP
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)