fix(profiling): fix pi0 cuBLAS error and pi05 OOM on 22GB GPU

- Move cudnn_deterministic to per-spec train_args instead of hardcoding
  it for all models. cuBLAS deterministic mode triggers internal errors
  on Gemma-based models (pi0, pi05) during backward pass.
- Enable use_amp=true for pi0, pi0_fast, and pi05 to reduce memory
  footprint from fp32 (~16GB weights alone) to bf16, fitting within
  22GB GPU budget with room for activations and gradients.
- Small models (act, diffusion, multi_task_dit) still use deterministic
  mode for reproducible profiling results.

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-16 15:34:04 +02:00
parent e16a95a78e
commit dbe01b0444
3 changed files with 9 additions and 5 deletions
+9 -3
View File
@@ -6,7 +6,8 @@
"--dataset.episodes=[0]",
"--policy.type=act",
"--policy.device=cuda",
"--batch_size=4"
"--batch_size=4",
"--cudnn_deterministic=true"
]
},
"diffusion": {
@@ -16,7 +17,8 @@
"--dataset.episodes=[0]",
"--policy.type=diffusion",
"--policy.device=cuda",
"--batch_size=4"
"--batch_size=4",
"--cudnn_deterministic=true"
]
},
"groot": {
@@ -45,7 +47,8 @@
"--policy.device=cuda",
"--policy.horizon=32",
"--policy.n_action_steps=30",
"--batch_size=4"
"--batch_size=4",
"--cudnn_deterministic=true"
]
},
"pi0": {
@@ -56,6 +59,7 @@
"--policy.path=lerobot/pi0_base",
"--policy.device=cuda",
"--policy.n_action_steps=30",
"--policy.use_amp=true",
"--batch_size=1",
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
]
@@ -68,6 +72,7 @@
"--policy.path=lerobot/pi0fast-base",
"--policy.device=cuda",
"--policy.n_action_steps=30",
"--policy.use_amp=true",
"--batch_size=1",
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
]
@@ -80,6 +85,7 @@
"--policy.path=lerobot/pi05_base",
"--policy.device=cuda",
"--policy.n_action_steps=30",
"--policy.use_amp=true",
"--batch_size=1",
"--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", \"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}",
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"