mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 12:47:18 +00:00
siglip again
This commit is contained in:
@@ -22,17 +22,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
# ReWiND dependencies
|
||||
try:
|
||||
from x_transformers import Decoder
|
||||
import einx
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ReWiND dependencies not installed. Please install: "
|
||||
"pip install x-transformers einx einops x-mlps-pytorch"
|
||||
) from e
|
||||
|
||||
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
@@ -95,17 +84,17 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model))
|
||||
|
||||
# Cross-modal sequential aggregator – causal transformer over
|
||||
# [language tokens | video frame tokens]
|
||||
self.decoder = Decoder(
|
||||
dim=config.dim_model,
|
||||
depth=config.num_layers,
|
||||
heads=config.num_heads,
|
||||
ff_mult=config.ff_mult,
|
||||
attn_dropout=config.dropout,
|
||||
ff_dropout=config.dropout,
|
||||
cross_attend=False,
|
||||
causal=True,
|
||||
# [language tokens | video frame tokens] using PyTorch TransformerEncoder
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=config.dim_model,
|
||||
nhead=config.num_heads,
|
||||
dim_feedforward=config.dim_model * config.ff_mult,
|
||||
dropout=config.dropout,
|
||||
activation="gelu",
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.aggregator = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
|
||||
|
||||
# Per-frame predictor head
|
||||
self.frame_mlp = nn.Sequential(
|
||||
@@ -138,7 +127,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
try:
|
||||
self.vision_model = torch.compile(self.vision_model, mode="reduce-overhead")
|
||||
self.text_model = torch.compile(self.text_model, mode="reduce-overhead")
|
||||
self.decoder = torch.compile(self.decoder, mode="reduce-overhead")
|
||||
self.aggregator = torch.compile(self.aggregator, mode="reduce-overhead")
|
||||
print("✅ Applied torch.compile to encoders and transformer")
|
||||
except Exception as e:
|
||||
print(f"⚠️ torch.compile failed: {e}")
|
||||
@@ -182,7 +171,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
(B, T, P, D_vision) where P is number of patch tokens per frame (excludes CLS)
|
||||
"""
|
||||
B, T, C, H, W = frames.shape
|
||||
flat = rearrange(frames, 'b t c h w -> (b t) c h w')
|
||||
flat = frames.reshape(B * T, C, H, W)
|
||||
|
||||
# Optimized: Process tensor directly without numpy conversion
|
||||
device = next(self.vision_model.parameters()).device
|
||||
@@ -222,8 +211,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean)
|
||||
try:
|
||||
cand1 = rearrange(patch_tokens_flat, '(b t) p d -> b t p d', b=B, t=T)
|
||||
cand2 = rearrange(patch_tokens_flat, '(t b) p d -> b t p d', t=T, b=B)
|
||||
P = patch_tokens_flat.shape[1]
|
||||
cand1 = patch_tokens_flat.reshape(B, T, P, -1)
|
||||
cand2 = patch_tokens_flat.reshape(T, B, P, -1).permute(1, 0, 2, 3)
|
||||
def mean_time_diff_4d(x):
|
||||
if T <= 1:
|
||||
return torch.tensor(0.0, device=x.device)
|
||||
@@ -237,7 +227,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
|
||||
except Exception:
|
||||
# Fallback to default
|
||||
patch_features = rearrange(patch_tokens_flat, '(b t) p d -> b t p d', b=B, t=T)
|
||||
P = patch_tokens_flat.shape[1]
|
||||
patch_features = patch_tokens_flat.view(B, T, P, -1)
|
||||
|
||||
# DEBUG: Analyze vision feature variability (use per-frame pooled features for readability)
|
||||
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
|
||||
@@ -297,7 +288,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2:
|
||||
# Recover CLS tokens
|
||||
cls_flat = tokens[:, 0, :] # (BT, D)
|
||||
cls = rearrange(cls_flat, '(b t) d -> b t d', b=B, t=T)
|
||||
cls = cls_flat.view(B, T, -1)
|
||||
b0 = 0
|
||||
f0, f1 = 0, T - 1
|
||||
# L2 between CLS at two frames
|
||||
@@ -421,16 +412,26 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# First-frame positional embedding only
|
||||
video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
|
||||
|
||||
# Build attention mask for decoder (True = keep)
|
||||
# Language mask from tokenizer, rest are fully valid
|
||||
full_mask = F.pad(mask, (0, video_tokens.shape[1]), value=True)
|
||||
# Pack and run transformer
|
||||
# Build masks for TransformerEncoder
|
||||
lang_valid = mask # (B, L) True where valid
|
||||
video_valid = torch.ones(B, video_tokens.shape[1], device=device, dtype=torch.bool)
|
||||
valid_mask = torch.cat([lang_valid, video_valid], dim=1) # (B, S)
|
||||
key_padding_mask = ~valid_mask # True -> masked
|
||||
|
||||
tokens_seq = torch.cat([lang_tokens, video_tokens], dim=1) # (B, S, D)
|
||||
|
||||
# Causal mask (S, S): True masks out future positions
|
||||
S = tokens_seq.shape[1]
|
||||
causal_mask = torch.triu(torch.ones(S, S, device=device, dtype=torch.bool), diagonal=1)
|
||||
|
||||
transformer_start = time.perf_counter()
|
||||
tokens_packed, packed_shape = pack((lang_tokens, video_tokens), 'b * d')
|
||||
attended = self.decoder(tokens_packed, mask=full_mask)
|
||||
attended_lang, attended_video = unpack(attended, packed_shape, 'b * d')
|
||||
attended_all = self.aggregator(tokens_seq, src_key_padding_mask=key_padding_mask, mask=causal_mask)
|
||||
transformer_time = time.perf_counter() - transformer_start
|
||||
|
||||
# Split back video part
|
||||
L_len = lang_tokens.shape[1]
|
||||
attended_video = attended_all[:, L_len:, :]
|
||||
|
||||
# Per-frame prediction
|
||||
frame_tokens = self.frame_mlp(attended_video) # (B, T_eff, D)
|
||||
raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff)
|
||||
@@ -496,11 +497,13 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
|
||||
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
|
||||
|
||||
# Pack and forward
|
||||
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, video_tokens), 'b * d')
|
||||
mask_mm = F.pad(mask_mm, (0, video_tokens.shape[1]), value=True)
|
||||
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
|
||||
_, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
|
||||
# Pack and forward with masks
|
||||
lang_valid_mm = mask_mm
|
||||
valid_mask_mm = torch.cat([lang_valid_mm, video_valid], dim=1)
|
||||
key_padding_mask_mm = ~valid_mask_mm
|
||||
tokens_seq_mm = torch.cat([lang_tokens_mm, video_tokens], dim=1)
|
||||
attended_all_mm = self.aggregator(tokens_seq_mm, src_key_padding_mask=key_padding_mask_mm, mask=causal_mask)
|
||||
attended_video_mm = attended_all_mm[:, L_len:, :]
|
||||
|
||||
# Process mismatch frames with single MLP
|
||||
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
|
||||
|
||||
Reference in New Issue
Block a user