add pos relative

This commit is contained in:
Pepijn
2025-08-31 00:43:26 +02:00
parent 195cc79c49
commit be9bdc242f
2 changed files with 14 additions and 30 deletions
+5 -14
View File
@@ -113,7 +113,7 @@
"source": [
"# Configuration\n",
"DATASET_REPO = \"pepijn223/phone_pipeline_pickup1\" # Change to your dataset\n",
"MODEL_PATH = \"pepijn223/rlearn_mse\" # Change to your model checkpoint\n",
"MODEL_PATH = \"pepijn223/rlearn_mse0\" # Change to your model checkpoint\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\"\n",
"NUM_EVAL_EPISODES = 10 # Number of episodes for evaluation\n",
"\n",
@@ -134,7 +134,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -142,13 +142,13 @@
"output_type": "stream",
"text": [
"Setting up model...\n",
"Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse\n"
"Loading trained model from Hugging Face Hub: pepijn223/rlearn_mse0\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0f302188511145498c49e7abf9408e5f",
"model_id": "f20c9f241f54445ca278d3c234ecb790",
"version_major": 2,
"version_minor": 0
},
@@ -179,7 +179,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fcb2fe307e5846c7bc364dd3be14cd57",
"model_id": "d01c168730e9474cb4ae8e596a78bf26",
"version_major": 2,
"version_minor": 0
},
@@ -189,15 +189,6 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Model ready on mps\n",
" Parameters: 770,064,901\n",
" Trainable: 19,688,961\n"
]
}
],
"source": [