fix(profiling): fix pi0 cuBLAS error and pi05 OOM on 22GB GPU

- Move cudnn_deterministic to per-spec train_args instead of hardcoding
  it for all models. cuBLAS deterministic mode triggers internal errors
  on Gemma-based models (pi0, pi05) during backward pass.
- Enable use_amp=true for pi0, pi0_fast, and pi05 to reduce memory
  footprint from fp32 (~16GB weights alone) to bf16, fitting within
  22GB GPU budget with room for activations and gradients.
- Small models (act, diffusion, multi_task_dit) still use deterministic
  mode for reproducible profiling results.

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-16 15:34:04 +02:00
parent e16a95a78e
commit dbe01b0444
3 changed files with 9 additions and 5 deletions
-1
View File
@@ -160,7 +160,6 @@ def build_train_command(spec: ProfilingSpec, run_dir: Path, profile_mode: str) -
"--policy.push_to_hub=false",
"--num_workers=0",
"--log_freq=1",
"--cudnn_deterministic=true",
f"--profile_mode={profile_mode}",
f"--profile_output_dir={profile_output_dir}",
]