mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
remove full pos embedding
This commit is contained in:
@@ -80,7 +80,7 @@ class RLearNConfig(PreTrainedConfig):
|
||||
# ReWiND augmentation
|
||||
rewind_prob: float = 0.3 #0.8
|
||||
rewind_last3_prob: float = 0.0 #0.3
|
||||
mismatch_prob: float = 0.0# 0.2
|
||||
mismatch_prob: float = 0.0 #0.2
|
||||
|
||||
# Normalization presets
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
|
||||
@@ -86,10 +86,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# First-frame positional embedding (only applied to the first video frame)
|
||||
self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model))
|
||||
# Full temporal positional embeddings (length = max_seq_len)
|
||||
self.max_time = config.max_seq_len
|
||||
self.time_pos = nn.Parameter(torch.zeros(1, self.max_time, config.dim_model))
|
||||
nn.init.trunc_normal_(self.time_pos, std=0.02)
|
||||
|
||||
# Cross-modal sequential aggregator – causal transformer over
|
||||
# [language tokens | video frame tokens] using PyTorch TransformerEncoder
|
||||
@@ -330,9 +326,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# SigLIP2 CLS per-frame already returned
|
||||
video_frame_embeds = video_patch_embeds # (B, T_eff, D_vision)
|
||||
video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D)
|
||||
# Add temporal positional embeddings
|
||||
video_tokens = video_tokens + self.time_pos[:, :T_eff, :]
|
||||
# Optional: keep a first-frame tag
|
||||
# ReWiND: only add a first-frame positional bias (prevents time cheating)
|
||||
video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
|
||||
|
||||
# Build masks for TransformerEncoder
|
||||
|
||||
Reference in New Issue
Block a user