diff --git a/tests/training/test_multi_gpu.py b/tests/training/test_multi_gpu.py index c24af6138..3b9cc8834 100644 --- a/tests/training/test_multi_gpu.py +++ b/tests/training/test_multi_gpu.py @@ -218,5 +218,5 @@ class TestMultiGPUTraining: assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}" # Verify optimizer state exists - optimizer_state = training_state_dir / "optimizer_state.pt" + optimizer_state = training_state_dir / "optimizer_state.safetensors" assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"