mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
add pos info for all frames
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -50,7 +50,6 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
dim_feedforward: int = 2048
|
dim_feedforward: int = 2048
|
||||||
dropout: float = 0.1
|
dropout: float = 0.1
|
||||||
pre_norm: bool = True
|
pre_norm: bool = True
|
||||||
use_first_frame_positional_bias: bool = True
|
|
||||||
frame_dropout_p: float = 0.0
|
frame_dropout_p: float = 0.0
|
||||||
stride: int = 1
|
stride: int = 1
|
||||||
|
|
||||||
|
|||||||
@@ -162,8 +162,16 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
|
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
|
||||||
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
|
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
|
||||||
|
|
||||||
# Only first frame gets a positional embed (no cheating on progress)
|
# Full positional encoding for all frames (helps learn temporal structure)
|
||||||
self.first_pos_emb = nn.Parameter(torch.randn(config.dim_model) * 1e-2)
|
# Using sinusoidal positional encoding for better temporal understanding
|
||||||
|
pe = torch.zeros(config.max_seq_len, config.dim_model)
|
||||||
|
position = torch.arange(0, config.max_seq_len).unsqueeze(1).float()
|
||||||
|
div_term = torch.exp(torch.arange(0, config.dim_model, 2).float() *
|
||||||
|
-(math.log(10000.0) / config.dim_model))
|
||||||
|
pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
|
pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
|
self.pos_embedding = nn.Parameter(pe, requires_grad=True)
|
||||||
|
self.first_pos_emb = None
|
||||||
|
|
||||||
# Register / memory / attention sink tokens
|
# Register / memory / attention sink tokens
|
||||||
self.num_register_tokens = config.num_register_tokens
|
self.num_register_tokens = config.num_register_tokens
|
||||||
@@ -263,10 +271,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
lang_tokens = self.to_lang_tokens(lang_embeds)
|
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||||
video_tokens = self.to_video_tokens(video_embeds)
|
video_tokens = self.to_video_tokens(video_embeds)
|
||||||
|
|
||||||
# Add first frame positional embedding
|
# Full positional encoding for temporal learning
|
||||||
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
|
T_video = video_tokens.shape[1]
|
||||||
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
|
video_tokens = video_tokens + self.pos_embedding[:T_video]
|
||||||
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
|
|
||||||
|
|
||||||
# Pack all tokens for attention
|
# Pack all tokens for attention
|
||||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||||
@@ -407,10 +414,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
lang_tokens = self.to_lang_tokens(lang_embeds)
|
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||||
video_tokens = self.to_video_tokens(video_embeds)
|
video_tokens = self.to_video_tokens(video_embeds)
|
||||||
|
|
||||||
# Add first frame positional embedding
|
|
||||||
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
|
# Full positional encoding for temporal learning
|
||||||
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
|
T_video = video_tokens.shape[1]
|
||||||
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
|
video_tokens = video_tokens + self.pos_embedding[:T_video]
|
||||||
|
|
||||||
# Pack all tokens for attention [lang | register | video]
|
# Pack all tokens for attention [lang | register | video]
|
||||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||||
|
|||||||
Reference in New Issue
Block a user