Replace sinusoidal positional encoding with a window-relative temporal positional embedding

This commit is contained in:
Pepijn
2025-08-31 00:43:26 +02:00
parent 195cc79c49
commit be9bdc242f
2 changed files with 14 additions and 30 deletions
+5 -14
View File
@@ -113,7 +113,7 @@
"source": [
"# Configuration\n",
"DATASET_REPO = \"pepijn223/phone_pipeline_pickup1\" # Change to your dataset\n",
"MODEL_PATH = \"pepijn223/rlearn_mse\" # Change to your model checkpoint\n",
"MODEL_PATH = \"pepijn223/rlearn_mse0\" # Change to your model checkpoint\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\"\n",
"NUM_EVAL_EPISODES = 10 # Number of episodes for evaluation\n",
"\n",
@@ -134,7 +134,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -142,13 +142,13 @@
"output_type": "stream",
"text": [
"Setting up model...\n",
"Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse\n"
"Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse0\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0f302188511145498c49e7abf9408e5f",
"model_id": "f20c9f241f54445ca278d3c234ecb790",
"version_major": 2,
"version_minor": 0
},
@@ -179,7 +179,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fcb2fe307e5846c7bc364dd3be14cd57",
"model_id": "d01c168730e9474cb4ae8e596a78bf26",
"version_major": 2,
"version_minor": 0
},
@@ -189,15 +189,6 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Model ready on mps\n",
" Parameters: 770,064,901\n",
" Trainable: 19,688,961\n"
]
}
],
"source": [
+8 -15
View File
@@ -162,16 +162,10 @@ class RLearNPolicy(PreTrainedPolicy):
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
# Full positional encoding for all frames (helps learn temporal structure)
# Using sinusoidal positional encoding for better temporal understanding
pe = torch.zeros(config.max_seq_len, config.dim_model)
position = torch.arange(0, config.max_seq_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, config.dim_model, 2).float() *
-(math.log(10000.0) / config.dim_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.pos_embedding = nn.Parameter(pe, requires_grad=True)
self.first_pos_emb = None
# Learned temporal positional embedding, indexed by position within the input window only.
# It lets the model use temporal order inside a 16-frame window without exposing the
# absolute position in the episode, which would allow it to "cheat" on episode-level
# progress prediction (episodes are 100-300 frames long).
self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.01)
# Register / memory / attention sink tokens
self.num_register_tokens = config.num_register_tokens
@@ -270,10 +264,9 @@ class RLearNPolicy(PreTrainedPolicy):
# Project embeddings
lang_tokens = self.to_lang_tokens(lang_embeds)
video_tokens = self.to_video_tokens(video_embeds)
# Full positional encoding for temporal learning
# Add temporal positional encoding (window-relative only)
T_video = video_tokens.shape[1]
video_tokens = video_tokens + self.pos_embedding[:T_video]
video_tokens = video_tokens + self.temporal_pos_embedding[:T_video]
# Pack all tokens for attention
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
@@ -415,9 +408,9 @@ class RLearNPolicy(PreTrainedPolicy):
video_tokens = self.to_video_tokens(video_embeds)
# Full positional encoding for temporal learning
# Add temporal positional encoding (window-relative only)
T_video = video_tokens.shape[1]
video_tokens = video_tokens + self.pos_embedding[:T_video]
video_tokens = video_tokens + self.temporal_pos_embedding[:T_video]
# Pack all tokens for attention [lang | register | video]
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')