fix(profiling): fix pi0 cuBLAS error and pi05 OOM on 22GB GPU

- Move cudnn_deterministic to per-spec train_args instead of hardcoding
  it for all models. cuBLAS deterministic mode triggers internal errors
  on Gemma-based models (pi0, pi05) during backward pass.
- Enable use_amp=true for pi0, pi0_fast, and pi05 so they train in bf16
  instead of fp32 (~16GB for weights alone), fitting within the 22GB
  GPU budget with room for activations and gradients.
- Small models (act, diffusion, multi_task_dit) still use deterministic
  mode for reproducible profiling results.
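
A minimal sketch of the per-spec layout described above. The spec names
and flag spellings come from this commit; the MODEL_SPECS dict and the
build_train_args() helper are hypothetical and only illustrate the idea,
they are not the repository's actual API.

```python
# Hypothetical per-spec table: each model carries its own train_args
# instead of every model inheriting a hardcoded cudnn_deterministic flag.
MODEL_SPECS = {
    # Small models keep deterministic mode for reproducible profiling.
    "act": {"train_args": {"cudnn_deterministic": True}},
    "diffusion": {"train_args": {"cudnn_deterministic": True}},
    "multi_task_dit": {"train_args": {"cudnn_deterministic": True}},
    # Gemma-based models skip deterministic mode and enable AMP.
    "pi0": {"train_args": {"use_amp": True}},
    "pi0_fast": {"train_args": {"use_amp": True}},
    "pi05": {"train_args": {"use_amp": True}},
}


def build_train_args(policy: str) -> list[str]:
    """Render one spec's train_args as --key=value CLI flags."""
    train_args = MODEL_SPECS[policy]["train_args"]
    return [f"--{key}={str(value).lower()}" for key, value in train_args.items()]
```

With this layout, build_train_args("act") includes
--cudnn_deterministic=true while build_train_args("pi0") does not, which
is why the shared test no longer asserts that flag for every command.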

Made-with: Cursor
Pepijn
2026-04-16 15:34:04 +02:00
parent e16a95a78e
commit dbe01b0444
3 changed files with 9 additions and 5 deletions
@@ -103,7 +103,6 @@ def test_build_train_command_includes_profiling_outputs(tmp_path):
     assert any(arg.startswith("--profile_output_dir=") for arg in cmd)
     assert "--profile_mode=trace" in cmd
     assert "--eval_freq=0" in cmd
-    assert "--cudnn_deterministic=true" in cmd
 
 
 def test_build_artifact_index_collects_tables_and_traces(tmp_path):