Fix DINOv3: run the vision encoder forward pass outside torch.inference_mode()

This commit is contained in:
Pepijn
2025-08-31 19:44:27 +02:00
parent e1d433cbfc
commit 79c3466f0f
@@ -377,9 +377,7 @@ class RLearNPolicy(PreTrainedPolicy):
inputs = {k: v.to(device) for k, v in inputs.items()}
# Process in batch through DINOv3 model
# Use inference mode for stable, fast frozen encoder forward
with torch.inference_mode():
vision_outputs = self.vision_model(**inputs)
vision_outputs = self.vision_model(**inputs)
# Prefer mean-pooled patch tokens over pooler/CLS to ensure input-dependent variation
if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None: