From be9bdc242ff6206db2f416a21b212a9db2f76c15 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 00:43:26 +0200 Subject: [PATCH] add pos relative --- notebooks/rlearn_evaluation.ipynb | 19 ++++---------- .../policies/rlearn/modeling_rlearn.py | 25 +++++++------------ 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/notebooks/rlearn_evaluation.ipynb b/notebooks/rlearn_evaluation.ipynb index 508a42934..9f6e46c76 100644 --- a/notebooks/rlearn_evaluation.ipynb +++ b/notebooks/rlearn_evaluation.ipynb @@ -113,7 +113,7 @@ "source": [ "# Configuration\n", "DATASET_REPO = \"pepijn223/phone_pipeline_pickup1\" # Change to your dataset\n", - "MODEL_PATH = \"pepijn223/rlearn_mse\" # Change to your model checkpoint\n", + "MODEL_PATH = \"pepijn223/rlearn_mse0\" # Change to your model checkpoint\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\"\n", "NUM_EVAL_EPISODES = 10 # Number of episodes for evaluation\n", "\n", @@ -134,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -142,13 +142,13 @@ "output_type": "stream", "text": [ "Setting up model...\n", - "Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse\n" + "Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse0\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0f302188511145498c49e7abf9408e5f", + "model_id": "f20c9f241f54445ca278d3c234ecb790", "version_major": 2, "version_minor": 0 }, @@ -179,7 +179,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fcb2fe307e5846c7bc364dd3be14cd57", + "model_id": "d01c168730e9474cb4ae8e596a78bf26", "version_major": 2, "version_minor": 0 }, @@ -189,15 +189,6 @@ }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "✓ Model ready on mps\n", - " Parameters: 770,064,901\n", - " Trainable: 19,688,961\n" - ] } ], "source": [ diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index a8535f418..b56a8f089 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -162,16 +162,10 @@ class RLearNPolicy(PreTrainedPolicy): self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model) self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model) - # Full positional encoding for all frames (helps learn temporal structure) - # Using sinusoidal positional encoding for better temporal understanding - pe = torch.zeros(config.max_seq_len, config.dim_model) - position = torch.arange(0, config.max_seq_len).unsqueeze(1).float() - div_term = torch.exp(torch.arange(0, config.dim_model, 2).float() * - -(math.log(10000.0) / config.dim_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - self.pos_embedding = nn.Parameter(pe, requires_grad=True) - self.first_pos_emb = None + # Temporal positional encoding for window-relative positions only + # This helps understand temporal order within 16-frame windows without enabling + # episode-level progress cheating (since episodes are 100-300 frames) + self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.01) # Register / memory / attention sink tokens self.num_register_tokens = config.num_register_tokens @@ -270,10 +264,9 @@ class RLearNPolicy(PreTrainedPolicy): # Project embeddings lang_tokens = self.to_lang_tokens(lang_embeds) video_tokens = self.to_video_tokens(video_embeds) - - # Full positional encoding for temporal learning - T_video = video_tokens.shape[1] - video_tokens = video_tokens + self.pos_embedding[:T_video] + # Add temporal positional encoding (window-relative only) + T_video = video_tokens.shape[1] + video_tokens = video_tokens + self.temporal_pos_embedding[:T_video] # Pack all tokens for attention tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d') @@ -415,9 +408,9 @@ class RLearNPolicy(PreTrainedPolicy): video_tokens = self.to_video_tokens(video_embeds) - # Full positional encoding for temporal learning + # Add temporal positional encoding (window-relative only) T_video = video_tokens.shape[1] - video_tokens = video_tokens + self.pos_embedding[:T_video] + video_tokens = video_tokens + self.temporal_pos_embedding[:T_video] # Pack all tokens for attention [lang | register | video] tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')