mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
add pos info for all frames
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -50,7 +50,6 @@ class RLearNConfig(PreTrainedConfig):
|
||||
dim_feedforward: int = 2048
|
||||
dropout: float = 0.1
|
||||
pre_norm: bool = True
|
||||
use_first_frame_positional_bias: bool = True
|
||||
frame_dropout_p: float = 0.0
|
||||
stride: int = 1
|
||||
|
||||
|
||||
@@ -162,8 +162,16 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
|
||||
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
|
||||
|
||||
# Only first frame gets a positional embed (no cheating on progress)
|
||||
self.first_pos_emb = nn.Parameter(torch.randn(config.dim_model) * 1e-2)
|
||||
# Full positional encoding for all frames (helps learn temporal structure)
|
||||
# Using sinusoidal positional encoding for better temporal understanding
|
||||
pe = torch.zeros(config.max_seq_len, config.dim_model)
|
||||
position = torch.arange(0, config.max_seq_len).unsqueeze(1).float()
|
||||
div_term = torch.exp(torch.arange(0, config.dim_model, 2).float() *
|
||||
-(math.log(10000.0) / config.dim_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
self.pos_embedding = nn.Parameter(pe, requires_grad=True)
|
||||
self.first_pos_emb = None
|
||||
|
||||
# Register / memory / attention sink tokens
|
||||
self.num_register_tokens = config.num_register_tokens
|
||||
@@ -263,10 +271,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||
video_tokens = self.to_video_tokens(video_embeds)
|
||||
|
||||
# Add first frame positional embedding
|
||||
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
|
||||
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
|
||||
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
|
||||
# Full positional encoding for temporal learning
|
||||
T_video = video_tokens.shape[1]
|
||||
video_tokens = video_tokens + self.pos_embedding[:T_video]
|
||||
|
||||
# Pack all tokens for attention
|
||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||
@@ -407,10 +414,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||
video_tokens = self.to_video_tokens(video_embeds)
|
||||
|
||||
# Add first frame positional embedding
|
||||
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
|
||||
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
|
||||
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
|
||||
|
||||
# Full positional encoding for temporal learning
|
||||
T_video = video_tokens.shape[1]
|
||||
video_tokens = video_tokens + self.pos_embedding[:T_video]
|
||||
|
||||
# Pack all tokens for attention [lang | register | video]
|
||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||
|
||||
Reference in New Issue
Block a user