This commit is contained in:
Pepijn
2025-08-30 12:05:38 +02:00
parent f5c39d6292
commit a4fc02a636
@@ -34,7 +34,7 @@ class RLearNConfig(PreTrainedConfig):
Notes: Notes:
- This follows the ReWiND paper architecture. It uses frozen vision/text encoders - This follows the ReWiND paper architecture. It uses frozen vision/text encoders
(DINO v3 for vision, sentence-transformers for language) and trains a (DINO v3 for vision, sentence-transformers for language) and trains a
lightweight temporal aggregator + head. lightweight temporal aggregator + head.
""" """
@@ -61,15 +61,17 @@ class RLearNConfig(PreTrainedConfig):
use_tanh_head: bool = False # when True, bound outputs in [-1, 1] use_tanh_head: bool = False # when True, bound outputs in [-1, 1]
# Training # Training
learning_rate: float = 1e-4 learning_rate: float = 3e-5
weight_decay: float = 0.01 weight_decay: float = 0.01
# ReWiND-specific parameters # ReWiND-specific parameters
use_video_rewind: bool = True # Enable video rewinding augmentation use_video_rewind: bool = False # Enable video rewinding augmentation
rewind_prob: float = 0.8 # Probability of applying rewind to each sample (paper: ~80%) rewind_prob: float = 0.8 # Probability of applying rewind to each sample (paper: ~80%)
rewind_last3_prob: float = 0.1 # Of the rewinds, 10% only rewind the last 3 frames rewind_last3_prob: float = 0.1 # Of the rewinds, 10% only rewind the last 3 frames
use_mismatch_loss: bool = True # Enable mismatched language-video loss use_mismatch_loss: bool = False # Enable mismatched language-video loss
mismatch_prob: float = 0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%) mismatch_prob: float = (
0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%)
)
# Loss hyperparameters (simplified for ReWiND) # Loss hyperparameters (simplified for ReWiND)
# The main loss is just MSE between predicted and target progress # The main loss is just MSE between predicted and target progress
@@ -85,7 +87,7 @@ class RLearNConfig(PreTrainedConfig):
# Architectural knobs to better mirror ReWiND # Architectural knobs to better mirror ReWiND
num_register_tokens: int = 4 num_register_tokens: int = 4
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head mlp_predictor_depth: int = 3 # depth of the per-frame MLP head
# HLGauss loss parameters # HLGauss loss parameters
use_hl_gauss_loss: bool = True use_hl_gauss_loss: bool = True
reward_min_value: float = 0.0 reward_min_value: float = 0.0