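Use smaller SigLIP2 checkpoint (base-patch16-224 instead of base-patch16-512) and rename leftover DINO labels in messages to SigLIP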

This commit is contained in:
Pepijn
2025-09-01 11:18:35 +02:00
parent d35ed3fd83
commit 3504d17fef
2 changed files with 5 additions and 5 deletions
@@ -39,8 +39,8 @@ class RLearNConfig(PreTrainedConfig):
"""
# Encoders - Use SigLIP2 for both vision and text (shared checkpoint)
vision_model_name: str = "google/siglip2-base-patch16-512"
text_model_name: str = "google/siglip2-base-patch16-512"
vision_model_name: str = "google/siglip2-base-patch16-224"
text_model_name: str = "google/siglip2-base-patch16-224"
freeze_backbones: bool = True
# Sequence length: number of frames (past frames plus the current one) used by the temporal model
@@ -207,7 +207,7 @@ class RLearNPolicy(PreTrainedPolicy):
# No per-patch tokens available, synthesize single patch from pooler
patch_tokens_flat = vision_outputs.pooler_output[:, None, :] # (BT, 1, D)
else:
raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output")
raise RuntimeError("SigLIP2 vision outputs do not contain last_hidden_state or pooler_output")
# Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean)
try:
@@ -224,7 +224,7 @@ class RLearNPolicy(PreTrainedPolicy):
diff2 = mean_time_diff_4d(cand2)
patch_features = cand1 if diff1 >= diff2 else cand2
if self.training and torch.rand(1).item() < 0.05:
print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
print(f"SigLIP reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
except Exception:
# Fallback to default
P = patch_tokens_flat.shape[1]
@@ -233,7 +233,7 @@ class RLearNPolicy(PreTrainedPolicy):
# DEBUG: Analyze vision feature variability (use per-frame pooled features for readability)
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
with torch.no_grad():
print(f"\n🔍 DINOv3 VISION FEATURE DEBUG (B={B}, T={T}):")
print(f"\n🔍 SigLIP2 VISION FEATURE DEBUG (B={B}, T={T}):")
# CRITICAL: Check if input frames are actually different
print(f"Raw frame tensor stats: mean={frames.mean():.6f}, std={frames.std():.6f}")