mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
smaller siglip2
This commit is contained in:
@@ -39,8 +39,8 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Encoders - Use SigLIP2 for both vision and text (shared checkpoint)
|
# Encoders - Use SigLIP2 for both vision and text (shared checkpoint)
|
||||||
vision_model_name: str = "google/siglip2-base-patch16-512"
|
vision_model_name: str = "google/siglip2-base-patch16-224"
|
||||||
text_model_name: str = "google/siglip2-base-patch16-512"
|
text_model_name: str = "google/siglip2-base-patch16-224"
|
||||||
freeze_backbones: bool = True
|
freeze_backbones: bool = True
|
||||||
|
|
||||||
# Sequence length, amount of past frames including current one to use in the temporal model
|
# Sequence length, amount of past frames including current one to use in the temporal model
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# No per-patch tokens available, synthesize single patch from pooler
|
# No per-patch tokens available, synthesize single patch from pooler
|
||||||
patch_tokens_flat = vision_outputs.pooler_output[:, None, :] # (BT, 1, D)
|
patch_tokens_flat = vision_outputs.pooler_output[:, None, :] # (BT, 1, D)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output")
|
raise RuntimeError("SigLIP2 vision outputs do not contain last_hidden_state or pooler_output")
|
||||||
|
|
||||||
# Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean)
|
# Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean)
|
||||||
try:
|
try:
|
||||||
@@ -224,7 +224,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
diff2 = mean_time_diff_4d(cand2)
|
diff2 = mean_time_diff_4d(cand2)
|
||||||
patch_features = cand1 if diff1 >= diff2 else cand2
|
patch_features = cand1 if diff1 >= diff2 else cand2
|
||||||
if self.training and torch.rand(1).item() < 0.05:
|
if self.training and torch.rand(1).item() < 0.05:
|
||||||
print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
|
print(f"SigLIP reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback to default
|
# Fallback to default
|
||||||
P = patch_tokens_flat.shape[1]
|
P = patch_tokens_flat.shape[1]
|
||||||
@@ -233,7 +233,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# DEBUG: Analyze vision feature variability (use per-frame pooled features for readability)
|
# DEBUG: Analyze vision feature variability (use per-frame pooled features for readability)
|
||||||
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
|
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
print(f"\n🔍 DINOv3 VISION FEATURE DEBUG (B={B}, T={T}):")
|
print(f"\n🔍 SigLIP2 VISION FEATURE DEBUG (B={B}, T={T}):")
|
||||||
|
|
||||||
# CRITICAL: Check if input frames are actually different
|
# CRITICAL: Check if input frames are actually different
|
||||||
print(f"Raw frame tensor stats: mean={frames.mean():.6f}, std={frames.std():.6f}")
|
print(f"Raw frame tensor stats: mean={frames.mean():.6f}, std={frames.std():.6f}")
|
||||||
|
|||||||
Reference in New Issue
Block a user