mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
smaller siglip2
This commit is contained in:
@@ -39,8 +39,8 @@ class RLearNConfig(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
# Encoders - Use SigLIP2 for both vision and text (shared checkpoint)
|
||||
vision_model_name: str = "google/siglip2-base-patch16-512"
|
||||
text_model_name: str = "google/siglip2-base-patch16-512"
|
||||
vision_model_name: str = "google/siglip2-base-patch16-224"
|
||||
text_model_name: str = "google/siglip2-base-patch16-224"
|
||||
freeze_backbones: bool = True
|
||||
|
||||
# Sequence length, amount of past frames including current one to use in the temporal model
|
||||
|
||||
@@ -207,7 +207,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# No per-patch tokens available, synthesize single patch from pooler
|
||||
patch_tokens_flat = vision_outputs.pooler_output[:, None, :] # (BT, 1, D)
|
||||
else:
|
||||
raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output")
|
||||
raise RuntimeError("SigLIP2 vision outputs do not contain last_hidden_state or pooler_output")
|
||||
|
||||
# Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean)
|
||||
try:
|
||||
@@ -224,7 +224,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
diff2 = mean_time_diff_4d(cand2)
|
||||
patch_features = cand1 if diff1 >= diff2 else cand2
|
||||
if self.training and torch.rand(1).item() < 0.05:
|
||||
print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
|
||||
print(f"SigLIP reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
|
||||
except Exception:
|
||||
# Fallback to default
|
||||
P = patch_tokens_flat.shape[1]
|
||||
@@ -233,7 +233,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# DEBUG: Analyze vision feature variability (use per-frame pooled features for readability)
|
||||
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
|
||||
with torch.no_grad():
|
||||
print(f"\n🔍 DINOv3 VISION FEATURE DEBUG (B={B}, T={T}):")
|
||||
print(f"\n🔍 SigLIP2 VISION FEATURE DEBUG (B={B}, T={T}):")
|
||||
|
||||
# CRITICAL: Check if input frames are actually different
|
||||
print(f"Raw frame tensor stats: mean={frames.mean():.6f}, std={frames.std():.6f}")
|
||||
|
||||
Reference in New Issue
Block a user