Remove the full temporal positional embedding (time_pos); keep only the first-frame positional bias

This commit is contained in:
Pepijn
2025-09-01 14:51:33 +02:00
parent cf0c3f0a9a
commit 88116b11e1
2 changed files with 2 additions and 8 deletions
@@ -86,10 +86,6 @@ class RLearNPolicy(PreTrainedPolicy):
# First-frame positional embedding (only applied to the first video frame) # First-frame positional embedding (only applied to the first video frame)
self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model)) self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model))
# Full temporal positional embeddings (length = max_seq_len)
self.max_time = config.max_seq_len
self.time_pos = nn.Parameter(torch.zeros(1, self.max_time, config.dim_model))
nn.init.trunc_normal_(self.time_pos, std=0.02)
# Cross-modal sequential aggregator causal transformer over # Cross-modal sequential aggregator causal transformer over
# [language tokens | video frame tokens] using PyTorch TransformerEncoder # [language tokens | video frame tokens] using PyTorch TransformerEncoder
@@ -330,9 +326,7 @@ class RLearNPolicy(PreTrainedPolicy):
# SigLIP2 CLS per-frame already returned # SigLIP2 CLS per-frame already returned
video_frame_embeds = video_patch_embeds # (B, T_eff, D_vision) video_frame_embeds = video_patch_embeds # (B, T_eff, D_vision)
video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D) video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D)
# Add temporal positional embeddings # ReWiND: only add a first-frame positional bias (prevents time cheating)
video_tokens = video_tokens + self.time_pos[:, :T_eff, :]
# Optional: keep a first-frame tag
video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
# Build masks for TransformerEncoder # Build masks for TransformerEncoder