add pos info for all frames

This commit is contained in:
Pepijn
2025-08-31 00:29:08 +02:00
parent f8d42cc038
commit 195cc79c49
3 changed files with 51 additions and 85 deletions
File diff suppressed because one or more lines are too long
@@ -50,7 +50,6 @@ class RLearNConfig(PreTrainedConfig):
dim_feedforward: int = 2048
dropout: float = 0.1
pre_norm: bool = True
use_first_frame_positional_bias: bool = True
frame_dropout_p: float = 0.0
stride: int = 1
+17 -10
View File
@@ -162,8 +162,16 @@ class RLearNPolicy(PreTrainedPolicy):
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
# Only first frame gets a positional embed (no cheating on progress)
self.first_pos_emb = nn.Parameter(torch.randn(config.dim_model) * 1e-2)
# Full positional encoding for all frames (helps learn temporal structure)
# Using sinusoidal positional encoding for better temporal understanding
pe = torch.zeros(config.max_seq_len, config.dim_model)
position = torch.arange(0, config.max_seq_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, config.dim_model, 2).float() *
-(math.log(10000.0) / config.dim_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.pos_embedding = nn.Parameter(pe, requires_grad=True)
self.first_pos_emb = None
# Register / memory / attention sink tokens
self.num_register_tokens = config.num_register_tokens
@@ -263,10 +271,9 @@ class RLearNPolicy(PreTrainedPolicy):
lang_tokens = self.to_lang_tokens(lang_embeds)
video_tokens = self.to_video_tokens(video_embeds)
# Add first frame positional embedding
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
# Full positional encoding for temporal learning
T_video = video_tokens.shape[1]
video_tokens = video_tokens + self.pos_embedding[:T_video]
# Pack all tokens for attention
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
@@ -407,10 +414,10 @@ class RLearNPolicy(PreTrainedPolicy):
lang_tokens = self.to_lang_tokens(lang_embeds)
video_tokens = self.to_video_tokens(video_embeds)
# Add first frame positional embedding
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
# Full positional encoding for temporal learning
T_video = video_tokens.shape[1]
video_tokens = video_tokens + self.pos_embedding[:T_video]
# Pack all tokens for attention [lang | register | video]
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')