From 3504d17fef2c20452f8c535606418185e02b7023 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 1 Sep 2025 11:18:35 +0200 Subject: [PATCH] smaller siglip2 --- src/lerobot/policies/rlearn/configuration_rlearn.py | 4 ++-- src/lerobot/policies/rlearn/modeling_rlearn.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index 5bb9141f7..3b7961db0 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -39,8 +39,8 @@ class RLearNConfig(PreTrainedConfig): """ # Encoders - Use SigLIP2 for both vision and text (shared checkpoint) - vision_model_name: str = "google/siglip2-base-patch16-512" - text_model_name: str = "google/siglip2-base-patch16-512" + vision_model_name: str = "google/siglip2-base-patch16-224" + text_model_name: str = "google/siglip2-base-patch16-224" freeze_backbones: bool = True # Sequence length, amount of past frames including current one to use in the temporal model diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 738bbf611..f2a70aa8f 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -207,7 +207,7 @@ class RLearNPolicy(PreTrainedPolicy): # No per-patch tokens available, synthesize single patch from pooler patch_tokens_flat = vision_outputs.pooler_output[:, None, :] # (BT, 1, D) else: - raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output") + raise RuntimeError("SigLIP2 vision outputs do not contain last_hidden_state or pooler_output") # Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean) try: @@ -224,7 +224,7 @@ class RLearNPolicy(PreTrainedPolicy): diff2 = mean_time_diff_4d(cand2) patch_features = cand1 if diff1 >= diff2 else cand2 if self.training and torch.rand(1).item() < 0.05: - print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}") + print(f"SigLIP reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}") except Exception: # Fallback to default P = patch_tokens_flat.shape[1] @@ -233,7 +233,7 @@ class RLearNPolicy(PreTrainedPolicy): # DEBUG: Analyze vision feature variability (use per-frame pooled features for readability) if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging with torch.no_grad(): - print(f"\nšŸ” DINOv3 VISION FEATURE DEBUG (B={B}, T={T}):") + print(f"\nšŸ” SigLIP2 VISION FEATURE DEBUG (B={B}, T={T}):") # CRITICAL: Check if input frames are actually different print(f"Raw frame tensor stats: mean={frames.mean():.6f}, std={frames.std():.6f}")