remove full pos embedding

This commit is contained in:
Pepijn
2025-09-01 14:51:33 +02:00
parent cf0c3f0a9a
commit 88116b11e1
2 changed files with 2 additions and 8 deletions
@@ -80,7 +80,7 @@ class RLearNConfig(PreTrainedConfig):
# ReWiND augmentation
rewind_prob: float = 0.3 #0.8
rewind_last3_prob: float = 0.0 #0.3
mismatch_prob: float = 0.0# 0.2
mismatch_prob: float = 0.0 #0.2
# Normalization presets
normalization_mapping: dict[str, NormalizationMode] = field(
@@ -86,10 +86,6 @@ class RLearNPolicy(PreTrainedPolicy):
# First-frame positional embedding (only applied to the first video frame)
self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model))
# Full temporal positional embeddings (length = max_seq_len)
self.max_time = config.max_seq_len
self.time_pos = nn.Parameter(torch.zeros(1, self.max_time, config.dim_model))
nn.init.trunc_normal_(self.time_pos, std=0.02)
# Cross-modal sequential aggregator: a causal transformer over
# [language tokens | video frame tokens], built with PyTorch's TransformerEncoder
@@ -330,9 +326,7 @@ class RLearNPolicy(PreTrainedPolicy):
# SigLIP2 CLS per-frame already returned
video_frame_embeds = video_patch_embeds # (B, T_eff, D_vision)
video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D)
# Add temporal positional embeddings
video_tokens = video_tokens + self.time_pos[:, :T_eff, :]
# Optional: keep a first-frame tag
# ReWiND: only add a first-frame positional bias (prevents time cheating)
video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
# Build masks for TransformerEncoder