use cls token

Author: Pepijn
Date:   2025-09-01 11:31:28 +02:00
parent 3504d17fef
commit 9a19f8f6f4
2 changed files with 33 additions and 42 deletions
@@ -67,7 +67,7 @@ class RLearNConfig(PreTrainedConfig):
     logit_eps: float = 1e-4

     # Performance optimizations
-    use_amp: bool = True
+    use_amp: bool = False
     compile_model: bool = True

     # ReWiND augmentation
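The `use_amp` default flips from `True` to `False`, making mixed precision opt-in. For context, a minimal sketch of how a flag like this typically gates autocast in a training step; the trainer wiring is outside this diff, so the function below and its names are assumptions:

```python
import torch

# Hypothetical wiring (not part of this commit): cfg.use_amp gates torch.autocast.
def training_step(model, batch, use_amp: bool = False) -> torch.Tensor:
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    # With use_amp=False the context is a no-op and the step runs in full precision.
    with torch.autocast(device_type=device_type, enabled=use_amp):
        loss, _ = model(batch)
    return loss
```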
+32 -41
@@ -162,13 +162,13 @@ class RLearNPolicy(PreTrainedPolicy):
         raise NotImplementedError("RLearN is a reward model and does not select actions")

     def _encode_video_frames(self, frames: Tensor) -> Tensor:
-        """Encode video frames through DinoV3 to get per-frame PATCH embeddings.
+        """Encode video frames through SigLIP2 vision tower and return per-frame CLS embeddings.

         Args:
             frames: (B, T, C, H, W)

         Returns:
-            (B, T, P, D_vision) where P is number of patch tokens per frame (excludes CLS)
+            (B, T, D_vision) CLS token per frame
         """
         B, T, C, H, W = frames.shape
         flat = frames.reshape(B * T, C, H, W)
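A quick shape walk-through of the new contract (illustrative sizes; the actual token count and width depend on the checkpoint):

```python
import torch

B, T, C, H, W, D = 2, 8, 3, 224, 224, 768         # assumed sizes
frames = torch.randn(B, T, C, H, W)
flat = frames.reshape(B * T, C, H, W)              # (16, 3, 224, 224) fed to the vision tower
tokens = torch.randn(B * T, 197, D)                # stand-in for last_hidden_state
cls_per_frame = tokens[:, 0, :].reshape(B, T, D)   # (2, 8, 768): one CLS vector per frame
```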
@@ -195,40 +195,38 @@ class RLearNPolicy(PreTrainedPolicy):
         # Process in batch through SigLIP2 vision tower
         vision_outputs = self.vision_model(**inputs)

-        # Prefer patch tokens from last_hidden_state (exclude CLS at index 0)
+        # Prefer CLS token from last_hidden_state at index 0
         if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None:
             tokens = vision_outputs.last_hidden_state  # (BT, N_tokens, D)
-            if tokens.dim() == 3 and tokens.shape[1] > 1:
-                patch_tokens_flat = tokens[:, 1:, :]  # (BT, P, D)
+            if tokens.dim() == 3 and tokens.shape[1] >= 1:
+                cls_tokens_flat = tokens[:, 0, :]  # (BT, D)
             else:
-                # Only one token available → treat as single patch
-                patch_tokens_flat = tokens[:, :1, :]
+                # Fallback to pooler if structure unexpected
+                cls_tokens_flat = getattr(vision_outputs, 'pooler_output')
         elif hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None:
-            # No per-patch tokens available, synthesize single patch from pooler
-            patch_tokens_flat = vision_outputs.pooler_output[:, None, :]  # (BT, 1, D)
+            # Use pooled output
+            cls_tokens_flat = vision_outputs.pooler_output  # (BT, D)
         else:
             raise RuntimeError("SigLIP2 vision outputs do not contain last_hidden_state or pooler_output")

-        # Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean)
+        # Robustly reshape CLS to (B, T, D): detect correct flatten order by maximizing temporal variance
         try:
-            P = patch_tokens_flat.shape[1]
-            cand1 = patch_tokens_flat.reshape(B, T, P, -1)
-            cand2 = patch_tokens_flat.reshape(T, B, P, -1).permute(1, 0, 2, 3)
-            def mean_time_diff_4d(x):
+            D = cls_tokens_flat.shape[-1]
+            cand1 = cls_tokens_flat.reshape(B, T, D)
+            cand2 = cls_tokens_flat.reshape(T, B, D).permute(1, 0, 2)
+            def mean_time_diff_3d(x):
                 if T <= 1:
                     return torch.tensor(0.0, device=x.device)
-                x_mean = x.mean(dim=2)  # (B, T, D)
-                diffs = (x_mean[:, 1:, :] - x_mean[:, :-1, :]).pow(2).sum(dim=-1).sqrt()
+                diffs = (x[:, 1:, :] - x[:, :-1, :]).pow(2).sum(dim=-1).sqrt()
                 return diffs.mean()
-            diff1 = mean_time_diff_4d(cand1)
-            diff2 = mean_time_diff_4d(cand2)
-            patch_features = cand1 if diff1 >= diff2 else cand2
+            diff1 = mean_time_diff_3d(cand1)
+            diff2 = mean_time_diff_3d(cand2)
+            frame_features = cand1 if diff1 >= diff2 else cand2
             if self.training and torch.rand(1).item() < 0.05:
                 print(f"SigLIP reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
         except Exception:
             # Fallback to default
-            P = patch_tokens_flat.shape[1]
-            patch_features = patch_tokens_flat.view(B, T, P, -1)
+            frame_features = cls_tokens_flat.view(B, T, -1)

         # DEBUG: Analyze vision feature variability (use per-frame pooled features for readability)
         if self.training and torch.rand(1).item() < 0.1:  # 10% of training steps for more frequent debugging
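The reshape-order check above, extracted as a standalone sketch (names are illustrative; the logic mirrors the hunk):

```python
import torch

def pick_reshape(cls_flat: torch.Tensor, B: int, T: int) -> torch.Tensor:
    """Decide whether (B*T, D) features were stacked as (b t) or (t b)."""
    D = cls_flat.shape[-1]
    cand1 = cls_flat.reshape(B, T, D)                    # assumes (b t) stacking
    cand2 = cls_flat.reshape(T, B, D).permute(1, 0, 2)   # assumes (t b) stacking

    def mean_time_diff(x: torch.Tensor) -> torch.Tensor:
        # Mean L2 distance between consecutive frames within each clip.
        if T <= 1:
            return torch.tensor(0.0, device=x.device)
        return (x[:, 1:, :] - x[:, :-1, :]).pow(2).sum(dim=-1).sqrt().mean()

    return cand1 if mean_time_diff(cand1) >= mean_time_diff(cand2) else cand2
```

Whether maximizing frame-to-frame change is the right criterion is arguable, since consecutive frames of a real clip are usually more similar than the unrelated samples a wrong reshape would interleave; the 5% debug print is presumably there to sanity-check the choice.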
@@ -277,8 +275,8 @@ class RLearNPolicy(PreTrainedPolicy):
                 else:
                     print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")

-            # Check feature statistics (pooled over patches)
-            vision_features = patch_features.mean(dim=2)  # (B, T, D)
+            # Check feature statistics
+            vision_features = frame_features  # (B, T, D)
             feature_mean = vision_features.mean().item()
             feature_std = vision_features.std().item()
             print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}")
@@ -286,21 +284,13 @@ class RLearNPolicy(PreTrainedPolicy):
             # Extra DIAGNOSTIC: CLS vs patch mean/max deltas for one sample, two far-apart frames
             try:
                 if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2:
-                    # Recover CLS tokens
-                    cls_flat = tokens[:, 0, :]  # (BT, D)
-                    cls = cls_flat.view(B, T, -1)
+                    # Recover CLS tokens (already computed as frame_features)
+                    cls = frame_features
                     b0 = 0
                     f0, f1 = 0, T - 1
                     # L2 between CLS at two frames
                     cls_l2 = (cls[b0, f1] - cls[b0, f0]).pow(2).sum().sqrt().item()
-                    # Patch mean L2
-                    pm_f0 = patch_features[b0, f0].mean(dim=0)
-                    pm_f1 = patch_features[b0, f1].mean(dim=0)
-                    pm_l2 = (pm_f1 - pm_f0).pow(2).sum().sqrt().item()
-                    # Max over patches L2
-                    per_patch_l2 = (patch_features[b0, f1] - patch_features[b0, f0]).pow(2).sum(dim=1).sqrt()
-                    max_p_l2 = per_patch_l2.max().item()
-                    print(f"CLS ΔL2: {cls_l2:.6f} | mean(patches) ΔL2: {pm_l2:.6f} | max(patch) ΔL2: {max_p_l2:.6f}")
+                    print(f"CLS ΔL2: {cls_l2:.6f}")
             except Exception as _:
                 pass
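The diagnostic now reduces to a single distance; here it is as a stand-in sketch (random features, so the printed value is only illustrative):

```python
import torch

B, T, D = 2, 8, 768
frame_features = torch.randn(B, T, D)  # stand-in for the per-frame CLS features
# L2 distance between the first and last frame of clip 0; values near zero would
# suggest the encoder is mapping every frame to nearly the same embedding.
cls_l2 = (frame_features[0, T - 1] - frame_features[0, 0]).pow(2).sum().sqrt().item()
print(f"CLS ΔL2: {cls_l2:.6f}")
```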
@@ -344,7 +334,7 @@ class RLearNPolicy(PreTrainedPolicy):
print("=" * 50)
return patch_features
return frame_features
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Compute ReWiND training loss with on-the-fly progress label generation.
@@ -406,8 +396,8 @@ class RLearNPolicy(PreTrainedPolicy):
         # Token preparation
         # Project embeddings
         lang_tokens = self.to_lang_tokens(lang_embeds)  # (B, L, D)
-        # Collapse patches to per-frame tokens then project
-        video_frame_embeds = video_patch_embeds.mean(dim=2)  # (B, T_eff, D_vision)
+        # SigLIP2 CLS per-frame already returned
+        video_frame_embeds = video_patch_embeds  # (B, T_eff, D_vision)
         video_tokens = self.to_video_tokens(video_frame_embeds)  # (B, T_eff, D)

         # First-frame positional embedding only
         video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
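A minimal sketch of the projection step above with stand-in modules (widths and layer shapes are assumptions; the diff only shows the call sites):

```python
import torch
import torch.nn as nn

D_vision, D_model = 768, 512                          # assumed widths
to_video_tokens = nn.Linear(D_vision, D_model)        # stand-in for self.to_video_tokens
first_frame_pos = nn.Parameter(torch.zeros(1, 1, D_model))

video_frame_embeds = torch.randn(2, 8, D_vision)      # (B, T_eff, D_vision) CLS per frame
video_tokens = to_video_tokens(video_frame_embeds)    # (B, T_eff, D_model)
# Positional embedding is added to the first frame only, marking the episode start.
video_tokens[:, :1, :] = video_tokens[:, :1, :] + first_frame_pos
```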
@@ -658,7 +648,7 @@ class RLearNPolicy(PreTrainedPolicy):
"""Return (embeddings, mask) for language tokens using SigLIP2.
embeddings: (B, L, D); mask: (B, L) True for valid tokens.
"""
# Optimized: Process all commands in batch (much faster than individual processing)
# Optimized: Process all commands in batch and take CLS token
proc = self.processor(
text=commands,
return_tensors='pt',
@@ -688,9 +678,10 @@ class RLearNPolicy(PreTrainedPolicy):
         # Batch encode through text model
         outputs = self.text_model.text_model(input_ids=input_ids, attention_mask=attention_mask)
-        last_hidden = outputs.last_hidden_state
-        mask = attention_mask.bool()
-        return last_hidden, mask
+        # Use CLS token (position 0) as single language token
+        cls_only = outputs.last_hidden_state[:, :1, :]
+        mask = torch.ones(cls_only.shape[:2], device=device, dtype=torch.bool)
+        return cls_only, mask

     def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
         """Try to extract (episode_index, frame_index) tensors from batch or complementary data.