use cls token

This commit is contained in:
Pepijn
2025-09-01 11:31:28 +02:00
parent 3504d17fef
commit 9a19f8f6f4
2 changed files with 33 additions and 42 deletions
@@ -67,7 +67,7 @@ class RLearNConfig(PreTrainedConfig):
logit_eps: float = 1e-4 logit_eps: float = 1e-4
# Performance optimizations # Performance optimizations
use_amp: bool = True use_amp: bool = False
compile_model: bool = True compile_model: bool = True
# ReWiND augmentation # ReWiND augmentation
+32 -41
View File
@@ -162,13 +162,13 @@ class RLearNPolicy(PreTrainedPolicy):
raise NotImplementedError("RLearN is a reward model and does not select actions") raise NotImplementedError("RLearN is a reward model and does not select actions")
def _encode_video_frames(self, frames: Tensor) -> Tensor: def _encode_video_frames(self, frames: Tensor) -> Tensor:
"""Encode video frames through DinoV3 to get per-frame PATCH embeddings. """Encode video frames through SigLIP2 vision tower and return per-frame CLS embeddings.
Args: Args:
frames: (B, T, C, H, W) frames: (B, T, C, H, W)
Returns: Returns:
(B, T, P, D_vision) where P is number of patch tokens per frame (excludes CLS) (B, T, D_vision) CLS token per frame
""" """
B, T, C, H, W = frames.shape B, T, C, H, W = frames.shape
flat = frames.reshape(B * T, C, H, W) flat = frames.reshape(B * T, C, H, W)
@@ -195,40 +195,38 @@ class RLearNPolicy(PreTrainedPolicy):
# Process in batch through SigLIP2 vision tower # Process in batch through SigLIP2 vision tower
vision_outputs = self.vision_model(**inputs) vision_outputs = self.vision_model(**inputs)
# Prefer patch tokens from last_hidden_state (exclude CLS at index 0) # Prefer CLS token from last_hidden_state at index 0
if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None: if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None:
tokens = vision_outputs.last_hidden_state # (BT, N_tokens, D) tokens = vision_outputs.last_hidden_state # (BT, N_tokens, D)
if tokens.dim() == 3 and tokens.shape[1] > 1: if tokens.dim() == 3 and tokens.shape[1] >= 1:
patch_tokens_flat = tokens[:, 1:, :] # (BT, P, D) cls_tokens_flat = tokens[:, 0, :] # (BT, D)
else: else:
# Only one token available → treat as single patch # Fallback to pooler if structure unexpected
patch_tokens_flat = tokens[:, :1, :] cls_tokens_flat = getattr(vision_outputs, 'pooler_output')
elif hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None: elif hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None:
# No per-patch tokens available, synthesize single patch from pooler # Use pooled output
patch_tokens_flat = vision_outputs.pooler_output[:, None, :] # (BT, 1, D) cls_tokens_flat = vision_outputs.pooler_output # (BT, D)
else: else:
raise RuntimeError("SigLIP2 vision outputs do not contain last_hidden_state or pooler_output") raise RuntimeError("SigLIP2 vision outputs do not contain last_hidden_state or pooler_output")
# Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean) # Robustly reshape CLS to (B, T, D): detect correct flatten order by maximizing temporal variance
try: try:
P = patch_tokens_flat.shape[1] D = cls_tokens_flat.shape[-1]
cand1 = patch_tokens_flat.reshape(B, T, P, -1) cand1 = cls_tokens_flat.reshape(B, T, D)
cand2 = patch_tokens_flat.reshape(T, B, P, -1).permute(1, 0, 2, 3) cand2 = cls_tokens_flat.reshape(T, B, D).permute(1, 0, 2)
def mean_time_diff_4d(x): def mean_time_diff_3d(x):
if T <= 1: if T <= 1:
return torch.tensor(0.0, device=x.device) return torch.tensor(0.0, device=x.device)
x_mean = x.mean(dim=2) # (B, T, D) diffs = (x[:, 1:, :] - x[:, :-1, :]).pow(2).sum(dim=-1).sqrt()
diffs = (x_mean[:, 1:, :] - x_mean[:, :-1, :]).pow(2).sum(dim=-1).sqrt()
return diffs.mean() return diffs.mean()
diff1 = mean_time_diff_4d(cand1) diff1 = mean_time_diff_3d(cand1)
diff2 = mean_time_diff_4d(cand2) diff2 = mean_time_diff_3d(cand2)
patch_features = cand1 if diff1 >= diff2 else cand2 frame_features = cand1 if diff1 >= diff2 else cand2
if self.training and torch.rand(1).item() < 0.05: if self.training and torch.rand(1).item() < 0.05:
print(f"SigLIP reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}") print(f"SigLIP reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
except Exception: except Exception:
# Fallback to default # Fallback to default
P = patch_tokens_flat.shape[1] frame_features = cls_tokens_flat.view(B, T, -1)
patch_features = patch_tokens_flat.view(B, T, P, -1)
# DEBUG: Analyze vision feature variability (use per-frame pooled features for readability) # DEBUG: Analyze vision feature variability (use per-frame pooled features for readability)
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
@@ -277,8 +275,8 @@ class RLearNPolicy(PreTrainedPolicy):
else: else:
print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}") print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
# Check feature statistics (pooled over patches) # Check feature statistics
vision_features = patch_features.mean(dim=2) # (B, T, D) vision_features = frame_features # (B, T, D)
feature_mean = vision_features.mean().item() feature_mean = vision_features.mean().item()
feature_std = vision_features.std().item() feature_std = vision_features.std().item()
print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}") print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}")
@@ -286,21 +284,13 @@ class RLearNPolicy(PreTrainedPolicy):
# Extra DIAGNOSTIC: CLS vs patch mean/max deltas for one sample, two far-apart frames # Extra DIAGNOSTIC: CLS vs patch mean/max deltas for one sample, two far-apart frames
try: try:
if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2: if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2:
# Recover CLS tokens # Recover CLS tokens (already computed as frame_features)
cls_flat = tokens[:, 0, :] # (BT, D) cls = frame_features
cls = cls_flat.view(B, T, -1)
b0 = 0 b0 = 0
f0, f1 = 0, T - 1 f0, f1 = 0, T - 1
# L2 between CLS at two frames # L2 between CLS at two frames
cls_l2 = (cls[b0, f1] - cls[b0, f0]).pow(2).sum().sqrt().item() cls_l2 = (cls[b0, f1] - cls[b0, f0]).pow(2).sum().sqrt().item()
# Patch mean L2 print(f"CLS ΔL2: {cls_l2:.6f}")
pm_f0 = patch_features[b0, f0].mean(dim=0)
pm_f1 = patch_features[b0, f1].mean(dim=0)
pm_l2 = (pm_f1 - pm_f0).pow(2).sum().sqrt().item()
# Max over patches L2
per_patch_l2 = (patch_features[b0, f1] - patch_features[b0, f0]).pow(2).sum(dim=1).sqrt()
max_p_l2 = per_patch_l2.max().item()
print(f"CLS ΔL2: {cls_l2:.6f} | mean(patches) ΔL2: {pm_l2:.6f} | max(patch) ΔL2: {max_p_l2:.6f}")
except Exception as _: except Exception as _:
pass pass
@@ -344,7 +334,7 @@ class RLearNPolicy(PreTrainedPolicy):
print("=" * 50) print("=" * 50)
return patch_features return frame_features
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Compute ReWiND training loss with on-the-fly progress label generation. """Compute ReWiND training loss with on-the-fly progress label generation.
@@ -406,8 +396,8 @@ class RLearNPolicy(PreTrainedPolicy):
# Token preparation # Token preparation
# Project embeddings # Project embeddings
lang_tokens = self.to_lang_tokens(lang_embeds) # (B, L, D) lang_tokens = self.to_lang_tokens(lang_embeds) # (B, L, D)
# Collapse patches to per-frame tokens then project # SigLIP2 CLS per-frame already returned
video_frame_embeds = video_patch_embeds.mean(dim=2) # (B, T_eff, D_vision) video_frame_embeds = video_patch_embeds # (B, T_eff, D_vision)
video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D) video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D)
# First-frame positional embedding only # First-frame positional embedding only
video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
@@ -658,7 +648,7 @@ class RLearNPolicy(PreTrainedPolicy):
"""Return (embeddings, mask) for language tokens using SigLIP2. """Return (embeddings, mask) for language tokens using SigLIP2.
embeddings: (B, L, D); mask: (B, L) True for valid tokens. embeddings: (B, L, D); mask: (B, L) True for valid tokens.
""" """
# Optimized: Process all commands in batch (much faster than individual processing) # Optimized: Process all commands in batch and take CLS token
proc = self.processor( proc = self.processor(
text=commands, text=commands,
return_tensors='pt', return_tensors='pt',
@@ -688,9 +678,10 @@ class RLearNPolicy(PreTrainedPolicy):
# Batch encode through text model # Batch encode through text model
outputs = self.text_model.text_model(input_ids=input_ids, attention_mask=attention_mask) outputs = self.text_model.text_model(input_ids=input_ids, attention_mask=attention_mask)
last_hidden = outputs.last_hidden_state # Use CLS token (position 0) as single language token
mask = attention_mask.bool() cls_only = outputs.last_hidden_state[:, :1, :]
return last_hidden, mask mask = torch.ones(cls_only.shape[:2], device=device, dtype=torch.bool)
return cls_only, mask
def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]: def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
"""Try to extract (episode_index, frame_index) tensors from batch or complementary data. """Try to extract (episode_index, frame_index) tensors from batch or complementary data.