mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
add pos relative
This commit is contained in:
@@ -113,7 +113,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# Configuration\n",
|
"# Configuration\n",
|
||||||
"DATASET_REPO = \"pepijn223/phone_pipeline_pickup1\" # Change to your dataset\n",
|
"DATASET_REPO = \"pepijn223/phone_pipeline_pickup1\" # Change to your dataset\n",
|
||||||
"MODEL_PATH = \"pepijn223/rlearn_mse\" # Change to your model checkpoint\n",
|
"MODEL_PATH = \"pepijn223/rlearn_mse0\" # Change to your model checkpoint\n",
|
||||||
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\"\n",
|
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\"\n",
|
||||||
"NUM_EVAL_EPISODES = 10 # Number of episodes for evaluation\n",
|
"NUM_EVAL_EPISODES = 10 # Number of episodes for evaluation\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -134,7 +134,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@@ -142,13 +142,13 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Setting up model...\n",
|
"Setting up model...\n",
|
||||||
"Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse\n"
|
"Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse0\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "0f302188511145498c49e7abf9408e5f",
|
"model_id": "f20c9f241f54445ca278d3c234ecb790",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
@@ -179,7 +179,7 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "fcb2fe307e5846c7bc364dd3be14cd57",
|
"model_id": "d01c168730e9474cb4ae8e596a78bf26",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
@@ -189,15 +189,6 @@
|
|||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "display_data"
|
"output_type": "display_data"
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"✓ Model ready on mps\n",
|
|
||||||
" Parameters: 770,064,901\n",
|
|
||||||
" Trainable: 19,688,961\n"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
|||||||
@@ -162,16 +162,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
|
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
|
||||||
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
|
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
|
||||||
|
|
||||||
# Full positional encoding for all frames (helps learn temporal structure)
|
# Temporal positional encoding for window-relative positions only
|
||||||
# Using sinusoidal positional encoding for better temporal understanding
|
# This helps understand temporal order within 16-frame windows without enabling
|
||||||
pe = torch.zeros(config.max_seq_len, config.dim_model)
|
# episode-level progress cheating (since episodes are 100-300 frames)
|
||||||
position = torch.arange(0, config.max_seq_len).unsqueeze(1).float()
|
self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.01)
|
||||||
div_term = torch.exp(torch.arange(0, config.dim_model, 2).float() *
|
|
||||||
-(math.log(10000.0) / config.dim_model))
|
|
||||||
pe[:, 0::2] = torch.sin(position * div_term)
|
|
||||||
pe[:, 1::2] = torch.cos(position * div_term)
|
|
||||||
self.pos_embedding = nn.Parameter(pe, requires_grad=True)
|
|
||||||
self.first_pos_emb = None
|
|
||||||
|
|
||||||
# Register / memory / attention sink tokens
|
# Register / memory / attention sink tokens
|
||||||
self.num_register_tokens = config.num_register_tokens
|
self.num_register_tokens = config.num_register_tokens
|
||||||
@@ -270,10 +264,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# Project embeddings
|
# Project embeddings
|
||||||
lang_tokens = self.to_lang_tokens(lang_embeds)
|
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||||
video_tokens = self.to_video_tokens(video_embeds)
|
video_tokens = self.to_video_tokens(video_embeds)
|
||||||
|
# Add temporal positional encoding (window-relative only)
|
||||||
# Full positional encoding for temporal learning
|
|
||||||
T_video = video_tokens.shape[1]
|
T_video = video_tokens.shape[1]
|
||||||
video_tokens = video_tokens + self.pos_embedding[:T_video]
|
video_tokens = video_tokens + self.temporal_pos_embedding[:T_video]
|
||||||
|
|
||||||
# Pack all tokens for attention
|
# Pack all tokens for attention
|
||||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||||
@@ -415,9 +408,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
video_tokens = self.to_video_tokens(video_embeds)
|
video_tokens = self.to_video_tokens(video_embeds)
|
||||||
|
|
||||||
|
|
||||||
# Full positional encoding for temporal learning
|
# Add temporal positional encoding (window-relative only)
|
||||||
T_video = video_tokens.shape[1]
|
T_video = video_tokens.shape[1]
|
||||||
video_tokens = video_tokens + self.pos_embedding[:T_video]
|
video_tokens = video_tokens + self.temporal_pos_embedding[:T_video]
|
||||||
|
|
||||||
# Pack all tokens for attention [lang | register | video]
|
# Pack all tokens for attention [lang | register | video]
|
||||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||||
|
|||||||
Reference in New Issue
Block a user