fix(stats): handle scalar stats robustly

- Wrap cast_stats_to_numpy with np.atleast_1d to prevent 0-d arrays
from scalar stats causing shape mismatches downstream.
This commit is contained in:
Khalil Meftah
2026-06-15 12:28:18 +02:00
parent 949f4fcbe9
commit 0efa3dc874
+1 -1
View File
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)