add pos relative

This commit is contained in:
Pepijn
2025-08-31 00:43:26 +02:00
parent 195cc79c49
commit be9bdc242f
2 changed files with 14 additions and 30 deletions
+5 -14
View File
@@ -113,7 +113,7 @@
"source": [ "source": [
"# Configuration\n", "# Configuration\n",
"DATASET_REPO = \"pepijn223/phone_pipeline_pickup1\" # Change to your dataset\n", "DATASET_REPO = \"pepijn223/phone_pipeline_pickup1\" # Change to your dataset\n",
"MODEL_PATH = \"pepijn223/rlearn_mse\" # Change to your model checkpoint\n", "MODEL_PATH = \"pepijn223/rlearn_mse0\" # Change to your model checkpoint\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\"\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\"\n",
"NUM_EVAL_EPISODES = 10 # Number of episodes for evaluation\n", "NUM_EVAL_EPISODES = 10 # Number of episodes for evaluation\n",
"\n", "\n",
@@ -134,7 +134,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -142,13 +142,13 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Setting up model...\n", "Setting up model...\n",
"Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse\n" "Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse0\n"
] ]
}, },
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": { "application/vnd.jupyter.widget-view+json": {
"model_id": "0f302188511145498c49e7abf9408e5f", "model_id": "f20c9f241f54445ca278d3c234ecb790",
"version_major": 2, "version_major": 2,
"version_minor": 0 "version_minor": 0
}, },
@@ -179,7 +179,7 @@
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": { "application/vnd.jupyter.widget-view+json": {
"model_id": "fcb2fe307e5846c7bc364dd3be14cd57", "model_id": "d01c168730e9474cb4ae8e596a78bf26",
"version_major": 2, "version_major": 2,
"version_minor": 0 "version_minor": 0
}, },
@@ -189,15 +189,6 @@
}, },
"metadata": {}, "metadata": {},
"output_type": "display_data" "output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Model ready on mps\n",
" Parameters: 770,064,901\n",
" Trainable: 19,688,961\n"
]
} }
], ],
"source": [ "source": [
+9 -16
View File
@@ -162,16 +162,10 @@ class RLearNPolicy(PreTrainedPolicy):
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model) self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model) self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
# Full positional encoding for all frames (helps learn temporal structure) # Temporal positional encoding for window-relative positions only
# Using sinusoidal positional encoding for better temporal understanding # This helps understand temporal order within 16-frame windows without enabling
pe = torch.zeros(config.max_seq_len, config.dim_model) # episode-level progress cheating (since episodes are 100-300 frames)
position = torch.arange(0, config.max_seq_len).unsqueeze(1).float() self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.01)
div_term = torch.exp(torch.arange(0, config.dim_model, 2).float() *
-(math.log(10000.0) / config.dim_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.pos_embedding = nn.Parameter(pe, requires_grad=True)
self.first_pos_emb = None
# Register / memory / attention sink tokens # Register / memory / attention sink tokens
self.num_register_tokens = config.num_register_tokens self.num_register_tokens = config.num_register_tokens
@@ -270,10 +264,9 @@ class RLearNPolicy(PreTrainedPolicy):
# Project embeddings # Project embeddings
lang_tokens = self.to_lang_tokens(lang_embeds) lang_tokens = self.to_lang_tokens(lang_embeds)
video_tokens = self.to_video_tokens(video_embeds) video_tokens = self.to_video_tokens(video_embeds)
# Add temporal positional encoding (window-relative only)
# Full positional encoding for temporal learning T_video = video_tokens.shape[1]
T_video = video_tokens.shape[1] video_tokens = video_tokens + self.temporal_pos_embedding[:T_video]
video_tokens = video_tokens + self.pos_embedding[:T_video]
# Pack all tokens for attention # Pack all tokens for attention
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d') tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
@@ -415,9 +408,9 @@ class RLearNPolicy(PreTrainedPolicy):
video_tokens = self.to_video_tokens(video_embeds) video_tokens = self.to_video_tokens(video_embeds)
# Full positional encoding for temporal learning # Add temporal positional encoding (window-relative only)
T_video = video_tokens.shape[1] T_video = video_tokens.shape[1]
video_tokens = video_tokens + self.pos_embedding[:T_video] video_tokens = video_tokens + self.temporal_pos_embedding[:T_video]
# Pack all tokens for attention [lang | register | video] # Pack all tokens for attention [lang | register | video]
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d') tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')